Commit 767e6e92 authored by Deepak Narayanan
Browse files

Simplify logic in megatron/fp16/fp16.py

parent aa9cae27
@@ -72,17 +72,10 @@ class FP16_Module(MegatronModule):
         self.add_module('module', module.half())
 
     def forward(self, *inputs, **kwargs):
-        convert_inputs = True
-        convert_outputs = True
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            if not mpu.is_pipeline_first_stage():
-                convert_inputs = False
-            if not mpu.is_pipeline_last_stage():
-                convert_outputs = False
-        if convert_inputs:
+        if mpu.is_pipeline_first_stage():
             inputs = fp32_to_fp16(inputs)
         outputs = self.module(*inputs, **kwargs)
-        if convert_outputs:
+        if mpu.is_pipeline_last_stage():
             outputs = fp16_to_fp32(outputs)
         return outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment