Commit 8ceed7c7 authored by Mohammad

changed gpt2 masking to binary and masked_fill

parent c0a59a66
@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores = torch.mul(attention_scores, ltor_mask) - \
-                       10000.0 * (1.0 - ltor_mask)
+    attention_scores.masked_fill_(ltor_mask, -10000.0)
     return attention_scores
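A minimal standalone sketch (names and shapes are illustrative, not part of this commit) of why the masked_fill_ form matches the previous multiply-and-subtract form once the mask is binary, with True marking the positions to block:

import torch

# Toy attention scores for one head over a sequence of length 4.
scores = torch.randn(1, 1, 4, 4)

# Old-style float mask: 1.0 = may attend, 0.0 = blocked (lower triangular for left-to-right).
float_mask = torch.tril(torch.ones(1, 1, 4, 4))

# New-style binary mask: True = blocked, as produced by (float_mask < 0.5).
bool_mask = float_mask < 0.5

# Previous formulation: keep allowed scores, push blocked ones to -10000.
old = torch.mul(scores, float_mask) - 10000.0 * (1.0 - float_mask)

# New formulation: fill blocked positions with -10000 in place.
new = scores.clone()
new.masked_fill_(bool_mask, -10000.0)

print(torch.allclose(old, new))  # True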
@@ -42,8 +42,7 @@ def get_batch(context_tokens):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, attention_mask, position_ids
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
-                                    eod_mask_loss,
-                                    fp16):
+                                    eod_mask_loss):
     """Build masks and position id for left to right model."""
     # Extract batch size and sequence length.
@@ -170,8 +169,7 @@ def get_ltor_masks_and_position_ids(data,
                 position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                 prev_index = i + 1
-    # Convert
-    if fp16:
-        attention_mask = attention_mask.half()
+    # Convert attention mask to binary:
+    attention_mask = (attention_mask < 0.5)
     return attention_mask, loss_mask, position_ids
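A condensed sketch of the mask construction after this change (EOD resetting and loss masking omitted; names are illustrative): the mask is built as floats, then converted to the boolean form that gpt2_attention_mask_func now expects, so no half-precision cast of the mask is needed.

import torch

def build_ltor_attention_mask(batch_size, seq_length):
    # Lower-triangular causal mask: 1.0 where a position may attend, 0.0 elsewhere.
    attention_mask = torch.tril(
        torch.ones((batch_size, 1, seq_length, seq_length)))
    # Convert attention mask to binary: True now marks positions to mask out,
    # matching the masked_fill_ call in gpt2_attention_mask_func.
    attention_mask = (attention_mask < 0.5)
    return attention_mask

mask = build_ltor_attention_mask(2, 8)
print(mask.dtype, mask.shape)  # torch.bool torch.Size([2, 1, 8, 8])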
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, labels, loss_mask, attention_mask, position_ids
@@ -71,8 +71,7 @@ def process_batch(batch):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, labels, attention_mask, position_ids, loss_mask
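For the call-site changes, a brief illustration (again a sketch, not code from this commit) of why dropping args.fp16 is safe: a torch.bool mask works with masked_fill_ regardless of the score tensor's dtype, so there is nothing left to cast to half precision on the mask side.

import torch

scores_fp16 = torch.randn(1, 1, 4, 4).half()           # fp16 attention scores
bool_mask = torch.tril(torch.ones(1, 1, 4, 4)) < 0.5   # binary mask, dtype torch.bool

masked = scores_fp16.masked_fill(bool_mask, -10000.0)
print(masked.dtype)  # torch.float16 -- no .half() needed for the mask itself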