Commit 22e3c7e6 authored by mohammad's avatar mohammad
Browse files

added fp16 cross entropy loss option for gpt2

parent acfe848e
...@@ -76,7 +76,7 @@ class GPT2Model(MegatronModule): ...@@ -76,7 +76,7 @@ class GPT2Model(MegatronModule):
if get_key_value: if get_key_value:
output = [output, presents] output = [output, presents]
if labels is not None: if labels is None:
return output return output
else: else:
loss = mpu.vocab_parallel_cross_entropy(output, labels) loss = mpu.vocab_parallel_cross_entropy(output, labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment