Commit 83671bbf authored by mshoeybi's avatar mshoeybi Committed by Deepak Narayanan
Browse files

Address Deepak's comments

parent 8bed1d63
...@@ -185,7 +185,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -185,7 +185,7 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection: if args.fp32_residual_connection:
assert args.fp16, \ assert args.fp16, \
'residual connection in fp32 only supports in fp16 mode.' 'residual connection in fp32 only supported when using fp16.'
# Activation checkpointing. # Activation checkpointing.
if args.distribute_checkpointed_activations: if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \ assert args.checkpoint_activations, \
......
...@@ -568,8 +568,10 @@ class ParallelTransformer(MegatronModule): ...@@ -568,8 +568,10 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
# Data format change to avoid explicit transposes : [b s h] --> [s b h]. # Data format change to avoid explicit transposes : [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection: if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float() hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else: else:
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment