Commit 83671bbf authored by mshoeybi's avatar mshoeybi Committed by Deepak Narayanan
Browse files

Address Deepak's comments

parent 8bed1d63
......@@ -185,7 +185,7 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16, \
'residual connection in fp32 only supports in fp16 mode.'
'residual connection in fp32 only supported when using fp16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
......
......@@ -568,8 +568,10 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_first_stage():
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment