Commit 8ceed7c7 authored by Mohammad

changed gpt2 masking to binary and masked_fill

parent c0a59a66
@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal

 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores = torch.mul(attention_scores, ltor_mask) - \
-        10000.0 * (1.0 - ltor_mask)
+    attention_scores.masked_fill_(ltor_mask, -10000.0)
     return attention_scores
...
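The hunk above swaps the multiplicative mask (float ltor_mask, 1.0 = keep, 0.0 = block, with a -10000 penalty added at blocked positions) for an in-place masked_fill_ driven by a boolean mask that is True where attention must be blocked. A minimal sketch (illustrative shapes and names, not from this repo) showing that the two forms produce the same scores:

import torch

# Old convention: float mask, 1.0 where attention is allowed.
scores = torch.randn(1, 1, 4, 4)
keep = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)
old = torch.mul(scores, keep) - 10000.0 * (1.0 - keep)

# New convention: bool mask, True where attention is BLOCKED;
# masked_fill_ writes -10000.0 in place at exactly those positions.
block = keep < 0.5
new = scores.clone().masked_fill_(block, -10000.0)

print(torch.allclose(old, new))  # True

Besides giving identical scores, the in-place fill avoids the elementwise multiplies and the (1.0 - ltor_mask) temporary on every forward pass.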
@@ -42,8 +42,7 @@ def get_batch(context_tokens):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, attention_mask, position_ids
...
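Here, and at the other two call sites later in this commit, the only change is dropping the trailing args.fp16 argument: since the mask is now boolean, it applies to scores of any floating dtype, so no .half() cast has to be threaded through. An illustrative check (not from the repo):

import torch

# A bool "blocked" mask works unchanged against fp32 and fp16 scores.
block = torch.tensor([[False, True], [False, False]])
for dtype in (torch.float32, torch.float16):
    scores = torch.zeros(2, 2, dtype=dtype)
    print(scores.masked_fill(block, -10000.0).dtype)  # dtype preserved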
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
-                                    eod_mask_loss,
-                                    fp16):
+                                    eod_mask_loss):
     """Build masks and position id for left to right model."""

     # Extract batch size and sequence length.
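With fp16 gone, the builder's signature depends only on the masking options. A miniature, hedged stand-in for what a builder like this returns (NOT the real implementation: it ignores the eod/reset options and exists only to show the shapes and the post-commit boolean convention):

import torch

def tiny_ltor_masks_and_position_ids(data):
    """data: (batch, seq_len) token ids; plain left-to-right masking."""
    b, s = data.size()
    # "Allowed" is lower-triangular; flip it into a binary "blocked"
    # mask, mirroring the convention introduced by this commit.
    allowed = torch.tril(torch.ones(1, 1, s, s, device=data.device))
    attention_mask = (allowed < 0.5)    # bool, True = blocked
    loss_mask = torch.ones(b, s, device=data.device)
    position_ids = torch.arange(s, device=data.device).unsqueeze(0).expand(b, s)
    return attention_mask, loss_mask, position_ids

attention_mask, loss_mask, position_ids = tiny_ltor_masks_and_position_ids(
    torch.zeros(2, 5, dtype=torch.long))
print(attention_mask.shape, attention_mask.dtype)  # (1, 1, 5, 5) torch.bool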
...@@ -170,8 +169,7 @@ def get_ltor_masks_and_position_ids(data, ...@@ -170,8 +169,7 @@ def get_ltor_masks_and_position_ids(data,
position_ids[b, (i + 1):] -= (i + 1 - prev_index) position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1 prev_index = i + 1
# Convert # Convert attention mask to binary:
if fp16: attention_mask = (attention_mask < 0.5)
attention_mask = attention_mask.half()
return attention_mask, loss_mask, position_ids return attention_mask, loss_mask, position_ids
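This is the core of the commit: instead of optionally casting the float mask to half precision, the mask is flipped once into a boolean "blocked" tensor. A small illustration (shapes assumed):

import torch

attention_mask = torch.tril(torch.ones(1, 1, 4, 4))  # 1.0 = allowed
attention_mask = (attention_mask < 0.5)              # True = blocked
print(attention_mask.dtype)           # torch.bool
print(attention_mask.element_size())  # 1 byte/element vs 2 for .half()
print(attention_mask[0, 0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])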
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, labels, loss_mask, attention_mask, position_ids
...
@@ -71,8 +71,7 @@ def process_batch(batch):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, labels, attention_mask, position_ids, loss_mask
...