Commit 8f160844 authored by mshoeybi

simple pipelining works

parent dee8707e
@@ -15,15 +15,33 @@
"""Forward step utilities."""
from abc import ABC
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum
from megatron.p2p_communication import recv_forward, send_forward
from megatron import get_args
import torch
from megatron import (
get_args,
mpu)
from megatron.p2p_communication import (
recv_forward,
send_forward)
class ForwardStepTypes(Enum):
NO_PIPELINING = 1
def forward_step_provider(model,
batch_size,
micro_batch_size,
max_sequence_len):
args = get_args()
if args.pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
return SimplePipeliningForwardStep(model, batch_size,
micro_batch_size,
max_sequence_len)
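For context, a minimal usage sketch of the new provider (not part of the commit; model, tokens, position_ids, and attention_mask stand in for whatever the generation driver supplies, and the micro batch size of 4 mirrors the hard-coded value used later in this commit):

# Sketch only: pick a forward-step implementation and run one forward pass.
# forward_step_provider returns NoPipeliningForwardStep when there is a single
# pipeline stage or the micro batch covers the whole batch, and
# SimplePipeliningForwardStep otherwise.
forward_step = forward_step_provider(model,
                                     batch_size=tokens.size(0),
                                     micro_batch_size=4,
                                     max_sequence_len=2048)
output = forward_step(tokens, position_ids, attention_mask)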
@@ -37,12 +55,12 @@ class InferenceParams:
self.micro_batch_size_list = micro_batch_size_list
self.max_sequence_len = max_sequence_len
self.allocate_key_value_memory = True
self.micro_batch_size_index = 0
self.micro_batch_index = 0
class InferenceForwardStep:
class ForwardStepBase(ABC):
def __init__(self, model, batch_size, max_sequence_len):
def __init__(self, model):
if isinstance(model, Iterable):
for this_model in model:
@@ -51,21 +69,100 @@ class InferenceForwardStep:
model.eval()
self.model = model
self.inference_params = InferenceParams([batch_size], max_sequence_len)
self.forward_step_type = ForwardStepTypes.NO_PIPELINING
@abstractmethod
def __call__(self, tokens, position_ids, attention_mask):
pass
class SimplePipeliningForwardStep(ForwardStepBase):
def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
super().__init__(model)
self.batch_size = batch_size
# Divide the batch dimension into micro batches.
self.num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
self.micro_batch_size_list = []
self.batch_dim_start_index = [0]
for i in range(self.num_micro_batches):
self.micro_batch_size_list.append(micro_batch_size)
self.batch_dim_start_index.append((i + 1) * micro_batch_size)
if last_chunk > 0:
self.num_micro_batches += 1
self.micro_batch_size_list.append(last_chunk)
self.batch_dim_start_index.append(batch_size)
self.inference_params = InferenceParams(self.micro_batch_size_list,
max_sequence_len)
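The splitting above can be read as a small standalone helper; a sketch with a worked example (the function name is illustrative, not part of the commit):

def split_into_micro_batches(batch_size, micro_batch_size):
    # Return (micro batch sizes, cumulative start indices along the batch dim),
    # mirroring micro_batch_size_list and batch_dim_start_index above.
    num_full, remainder = divmod(batch_size, micro_batch_size)
    sizes = [micro_batch_size] * num_full
    if remainder > 0:
        sizes.append(remainder)
    starts = [0]
    for size in sizes:
        starts.append(starts[-1] + size)
    return sizes, starts

# Example: batch_size=10, micro_batch_size=4 gives sizes [4, 4, 2] and start
# indices [0, 4, 8, 10], so the last micro batch is the leftover chunk.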
def __call__(self, tokens, position_ids, attention_mask):
if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
return self._forward_step_no_pipelining(tokens, position_ids,
attention_mask)
# Need to tell p2p_communicate functions the correct size.
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.size(1)
assert args.seq_length <= self.inference_params.max_sequence_len
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
logits = torch.empty(tokens.size(0),
tokens.size(1),
args.padded_vocab_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# Pipeline using micro batches.
for micro_batch_index in range(self.num_micro_batches):
# Set micro-batch size and index.
self.inference_params.micro_batch_index = micro_batch_index
args.micro_batch_size = self.micro_batch_size_list[
micro_batch_index]
# Slice along the batch dimension.
start = self.batch_dim_start_index[micro_batch_index]
end = self.batch_dim_start_index[micro_batch_index + 1]
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Receive from previous stage.
input_tensor = recv_forward()
# Forward pass through the model.
self.model.set_input_tensor(input_tensor)
output_tensor = self.model(tokens2use, position_ids2use,
attention_mask,
inference_params=self.inference_params)
# Send output to the next stage.
send_forward(output_tensor)
# Reset the sequence length to whatever it was before.
# Make sure we do not allocate context memory anymore.
if self.inference_params.allocate_key_value_memory:
self.inference_params.allocate_key_value_memory = False
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output_tensor
# Adjust the sequence length back to whatever it was before.
args.seq_length = orig_seq_length
return logits
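Stripped of the Megatron bookkeeping (args.seq_length, args.micro_batch_size, the key/value allocation flag), the schedule above reduces to the pattern below; this is a sketch only, and stage_forward, recv_from_prev, send_to_next, and collect are placeholders for self.model plus the p2p helpers:

def run_simple_pipelining(micro_batch_slices, stage_forward, recv_from_prev,
                          send_to_next, is_last_stage, collect):
    # One micro batch at a time flows through this pipeline stage; there is no
    # overlap of communication and compute, hence "simple" pipelining.
    for start, end in micro_batch_slices:
        input_chunk = recv_from_prev()           # None on the first stage
        output_chunk = stage_forward(input_chunk, start, end)
        send_to_next(output_chunk)               # no-op on the last stage
        if is_last_stage:
            collect(start, end, output_chunk)    # e.g. logits[start:end] = ...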
class NoPipeliningForwardStep(ForwardStepBase):
raise Exception('unknown forward step type {}'.format(
self.forward_step_type))
def __init__(self, model, batch_size, max_sequence_len):
super().__init__(model)
self.inference_params = InferenceParams([batch_size], max_sequence_len)
def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):
def __call__(self, tokens, position_ids, attention_mask):
# Need to tell p2p_communicate functions the correct size.
args = get_args()
@@ -74,7 +171,7 @@ class InferenceForwardStep:
assert args.seq_length <= self.inference_params.max_sequence_len
args.micro_batch_size = tokens.shape[0]
assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
assert self.inference_params.micro_batch_size_index == 0
assert self.inference_params.micro_batch_index == 0
# Receive from previous stage.
input_tensor = recv_forward()
@@ -94,27 +191,3 @@ class InferenceForwardStep:
self.inference_params.allocate_key_value_memory = False
return output_tensor
def forward_step(model, tokens, position_ids, attention_mask, inference_params):
# Hidden size changes when not using recompute, need to tell p2p_communicate
# functions the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
args.micro_batch_size = tokens.shape[0]
input_tensor = recv_forward()
# Forward pass through the model.
model.set_input_tensor(input_tensor)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
send_forward(output_tensor)
args.seq_length = orig_seq_length
return output_tensor
@@ -24,7 +24,7 @@ from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import InferenceForwardStep
from .forward_step import forward_step_provider
from .sampling import sample
@@ -66,7 +66,8 @@ def generate_tokens_probs_and_return_on_first_stage(
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# forward step.
forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
forward_step = forward_step_provider(model, batch_size, 4,
max_sequence_length)
# Added termination_id to support the case where we want to terminate the
# generation once that id is generated.
......
@@ -269,18 +269,19 @@ class ParallelAttention(MegatronModule):
# ==================================
if inference_params:
inf_batch_index = inference_params.micro_batch_size_index
inf_batch_index = inference_params.micro_batch_index
assert key_layer.size(1) == \
inference_params.micro_batch_size_list[inf_batch_index]
# Adjust the range variables.
start = self.inference_current_sequence_len_list[inf_batch_index]
end = start + key_layer.size(0)
assert end <= inference_params.max_sequence_len
self.inference_current_sequence_len_list[inf_batch_index] = end
# Copy key and values.
self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
key_layer
self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
value_layer
self.inference_key_memory_list[inf_batch_index][start:end, ...] \
= key_layer
self.inference_value_memory_list[inf_batch_index][start:end, ...] \
= value_layer
key_layer = \
self.inference_key_memory_list[inf_batch_index][:end, ...]
value_layer = \
......
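The surrounding hunk keeps one key/value memory per micro batch and implements the usual incremental KV cache: new keys and values are written at [start:end] along the sequence dimension, and attention then reads the whole prefix [:end]. A self-contained sketch of that pattern (illustrative shapes, assuming the [sequence, batch, ...] layout used above):

import torch

max_seq_len, micro_batch, dim = 2048, 4, 128
key_memory = torch.zeros(max_seq_len, micro_batch, dim)
current_len = 0

def append_keys(key_layer):
    # Write the new keys after the ones cached so far, then return the full
    # prefix that attention should see.
    global current_len
    start = current_len
    end = start + key_layer.size(0)
    assert end <= max_seq_len
    key_memory[start:end, ...] = key_layer
    current_len = end
    return key_memory[:end, ...]

keys = append_keys(torch.randn(16, micro_batch, dim))  # prompt of 16 tokens
keys = append_keys(torch.randn(1, micro_batch, dim))   # one generated token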