Commit 197c132e authored by mohammad

Addressed Jared's comments

parent 78022005
@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.num_unique_layers < args.num_layers:
         assert args.DDP_impl == 'local', \
             'torch-DDP does not work with parameters sharing.'
+    # Mixed precision checks.
+    if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
     _print_args(args)
     return args
......
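A minimal, self-contained sketch of how the new check behaves, assuming argparse flags that mirror the attribute names used above (the actual Megatron-LM argument definitions are not part of this diff):

import argparse

# Hypothetical flag names chosen to match args.fp16 and args.fp16_lm_cross_entropy.
parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true',
                    help='Run the model in fp16 mode.')
parser.add_argument('--fp16-lm-cross-entropy', action='store_true',
                    help='Compute the LM cross entropy in fp16 instead of fp32.')
args = parser.parse_args(['--fp16', '--fp16-lm-cross-entropy'])

# Same guard as in parse_args: fp16 cross entropy only makes sense in fp16 mode.
if args.fp16_lm_cross_entropy:
    assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'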
@@ -115,6 +115,7 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
@@ -170,7 +171,12 @@ class BertModel(MegatronModule):
         if lm_labels is None:
             return lm_logits, binary_logits
         else:
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            if self.fp16_lm_cross_entropy:
+                assert lm_logits.dtype == torch.half
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            else:
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                           lm_labels)
             return lm_loss, binary_logits
......
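For readers skimming the diff, a stand-alone sketch of the dtype-guarded loss path above, with torch.nn.functional.cross_entropy standing in for mpu.vocab_parallel_cross_entropy (the model-parallel version needs an initialized model-parallel group). Shapes, the helper name, and flag handling are illustrative only, not taken from this commit.

import torch
import torch.nn.functional as F

def lm_loss(lm_logits, lm_labels, fp16_lm_cross_entropy):
    # Flatten [batch, seq, vocab] logits to [batch*seq, vocab] for cross entropy.
    logits_2d = lm_logits.view(-1, lm_logits.size(-1))
    if fp16_lm_cross_entropy:
        # Keep the whole loss computation in half precision (intended for GPU fp16 runs).
        assert lm_logits.dtype == torch.half
        return F.cross_entropy(logits_2d, lm_labels.view(-1))
    # Default path: upcast logits to fp32 before the numerically safer loss.
    return F.cross_entropy(logits_2d.float(), lm_labels.view(-1))

logits = torch.randn(2, 5, 11, dtype=torch.half)   # toy [batch, seq, vocab]
labels = torch.randint(0, 11, (2, 5))
print(lm_loss(logits, labels, fp16_lm_cross_entropy=False))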
@@ -40,6 +40,7 @@ class GPT2Model(MegatronModule):
         args = get_args()
         self.parallel_output = parallel_output
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=gpt2_attention_mask_func,
@@ -79,7 +80,11 @@ class GPT2Model(MegatronModule):
         if labels is None:
             return output
         else:
-            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            if self.fp16_lm_cross_entropy:
+                assert output.dtype == torch.half
+                loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            else:
+                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
             return loss
......
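A back-of-the-envelope sketch of why the fp16 path matters for a GPT-2-sized vocabulary: upcasting the full-vocabulary logits with .float() materializes an extra fp32 copy of a [batch, seq, vocab] tensor. The batch and sequence sizes below are assumed for illustration, not measured from this commit.

# Rough memory estimate for the logits tensor and the extra fp32 copy made by .float().
batch, seq, vocab = 8, 1024, 50257   # 50257 is the GPT-2 vocabulary size
half_bytes = batch * seq * vocab * 2         # fp16 logits
float_copy_bytes = batch * seq * vocab * 4   # additional fp32 copy
print(f'fp16 logits: {half_bytes / 2**20:.0f} MiB, '
      f'extra fp32 copy: {float_copy_bytes / 2**20:.0f} MiB')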
@@ -78,16 +78,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
     # Forward model. lm_labels
-    if args.fp16_lm_cross_entropy:
-        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
-                                     lm_labels=lm_labels)
-    else:
-        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
-        lm_loss_ = mpu.vocab_parallel_cross_entropy(
-            lm_logits.contiguous().float(), lm_labels.contiguous())
+    lm_loss_, sop_logits = model(tokens, padding_mask,
+                                 tokentype_ids=types,
+                                 lm_labels=lm_labels)
-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
-                               sentence_order.view(-1).contiguous(),
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                               sentence_order.view(-1),
                                ignore_index=-1)
     lm_loss = torch.sum(
......
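A toy version of the trimmed SOP (sentence-order) loss call above: the binary-head logits are upcast to fp32 and positions labelled -1 are ignored. Tensor shapes and values are illustrative only.

import torch
import torch.nn.functional as F

sop_logits = torch.randn(4, 2, dtype=torch.half)   # [batch, 2] binary head output
sentence_order = torch.tensor([0, 1, -1, 0])       # -1 marks positions to ignore
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                           sentence_order.view(-1),
                           ignore_index=-1)
print(sop_loss)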
@@ -82,13 +82,8 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
     # Forward model.
-    if args.fp16_lm_cross_entropy:
-        losses = model(tokens, position_ids, attention_mask, labels=labels)
-    else:
-        output = model(tokens, position_ids, attention_mask)
-        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                                  labels)
+    losses = model(tokens, position_ids, attention_mask, labels=labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......
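A toy version of the loss-mask averaging on the last line above: per-token losses are zeroed where the mask is 0 and averaged over the unmasked tokens only. The tensors are made up for illustration.

import torch

losses = torch.tensor([[0.5, 1.0, 2.0], [1.5, 0.0, 3.0]])   # per-token LM losses
loss_mask = torch.tensor([[1., 1., 0.], [1., 0., 1.]])      # 1 = real token, 0 = padding
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)   # mean over the 4 unmasked positions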