Commit b6555b71 authored by mshoeybi

working

parent 6c40f892
@@ -16,25 +16,10 @@
 """Forward step utilities."""
 import torch
 from megatron.p2p_communication import recv_forward, send_forward
-from .sampling import sample
-from megatron import mpu
-import torch.nn.functional as F
-from megatron import print_rank_0
-from megatron import get_args, get_tokenizer
-from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
-from .communication import (
-    broadcast_float_list,
-    copy_from_last_to_first_pipeline_stage,
-    broadcast_from_last_pipeline_stage)
-from .tokenization import tokenize_prompts
-# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module
+from megatron import get_args
 def forward_step(model, tokens, position_ids, attention_mask,
@@ -51,9 +36,7 @@ def forward_step(model, tokens, position_ids, attention_mask,
     input_tensor = recv_forward()
     # Forward pass through the model.
-    unwrapped_model = unwrap_model(
-        model, (torchDDP, LocalDDP, Float16Module))
-    unwrapped_model.set_input_tensor(input_tensor)
+    model.set_input_tensor(input_tensor)
     output_tensor = model(
         tokens, position_ids, attention_mask,
         set_inference_key_value_memory=set_inference_key_value_memory,
...
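With this change the text-generation forward step no longer unwraps the model before handing it the activations received from the previous pipeline stage; it simply calls set_input_tensor on whatever handle it holds. A minimal sketch of the resulting per-stage call path, assuming the surrounding wrappers delegate set_input_tensor as in the Float16Module hunk below (the helper name run_forward_stage and its argument list are illustrative, not part of this commit):

# Sketch only: recv_forward/send_forward are the p2p helpers imported above.
def run_forward_stage(model, tokens, position_ids, attention_mask):
    # Activations from the previous pipeline stage (None on the first stage).
    input_tensor = recv_forward()
    # No unwrap_model(model, (torchDDP, LocalDDP, Float16Module)) needed:
    # the wrapper forwards set_input_tensor to the core model.
    model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask)
    # Hand the activations to the next pipeline stage.
    send_forward(output_tensor)
    return output_tensor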
@@ -166,6 +166,10 @@ class Float16Module(MegatronModule):
         self.float16_convertor = float16_convertor
+    def set_input_tensor(self, input_tensor):
+        return self.module.set_input_tensor(input_tensor)
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
             inputs = fp32_to_float16(inputs, self.float16_convertor)
...
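The Float16Module change is the other half of the same pattern: each wrapper around the core model exposes set_input_tensor and forwards it inward, so the outermost handle behaves like the model itself. A minimal self-contained sketch of that delegation (CoreModel and Wrapper are illustrative stand-ins; only the body of set_input_tensor mirrors the diff):

import torch

class CoreModel(torch.nn.Module):
    # Stand-in for the pipeline-parallel transformer.
    def __init__(self):
        super().__init__()
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # First stage receives None; later stages receive the activations
        # sent by the previous stage.
        self.input_tensor = input_tensor

class Wrapper(torch.nn.Module):
    # Stand-in for Float16Module (and, by the same pattern, the DDP wrappers).
    def __init__(self, module):
        super().__init__()
        self.module = module

    def set_input_tensor(self, input_tensor):
        # Delegate inward, as Float16Module.set_input_tensor does in this commit.
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

# Wrapped and unwrapped models now expose the same interface:
model = Wrapper(Wrapper(CoreModel()))   # e.g. DDP(Float16Module(core_model))
model.set_input_tensor(torch.zeros(1))  # reaches CoreModel without unwrap_model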