"vscode:/vscode.git/clone" did not exist on "0bf6aeb885e624b17233a29bed8dbbe62c56d48e"
Commit 767e6e92 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Simplify logic in megatron/fp16/fp16.py

parent aa9cae27
......@@ -72,17 +72,10 @@ class FP16_Module(MegatronModule):
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs):
convert_inputs = True
convert_outputs = True
if mpu.get_pipeline_model_parallel_world_size() > 1:
if not mpu.is_pipeline_first_stage():
convert_inputs = False
if not mpu.is_pipeline_last_stage():
convert_outputs = False
if convert_inputs:
if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs)
outputs = self.module(*inputs, **kwargs)
if convert_outputs:
if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs)
return outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment