Commit a7539b0f authored by mshoeybi

pipelining works

parent 8f160844
@@ -15,8 +15,6 @@
"""Forward step utilities."""
from abc import ABC
from abc import abstractmethod
from collections.abc import Iterable
import torch
@@ -24,44 +22,27 @@ import torch
from megatron import (
get_args,
mpu)
from megatron.p2p_communication import (
recv_forward,
send_forward)
def forward_step_provider(model,
batch_size,
micro_batch_size,
max_sequence_len):
args = get_args()
if args.pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
return SimplePipeliningForwardStep(model, batch_size,
micro_batch_size,
max_sequence_len)
class InferenceParams:
def __init__(self, micro_batch_size_list, max_sequence_len):
assert isinstance(micro_batch_size_list, list)
assert max_sequence_len > 0
def __init__(self, max_batch_size, max_sequence_len):
self.micro_batch_size_list = micro_batch_size_list
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.allocate_key_value_memory = True
self.micro_batch_index = 0
class ForwardStepBase(ABC):
def __init__(self, model):
class ForwardStep:
def __init__(self, model, max_batch_size, max_sequence_len):
# Make sure model is in eval mode.
if isinstance(model, Iterable):
for this_model in model:
this_model.eval()
@@ -69,125 +50,148 @@ class ForwardStepBase(ABC):
model.eval()
self.model = model
@abstractmethod
def __call__(self, tokens, position_ids, attention_mask):
pass
# Threshold on batch_size * sequence_length above which the forward
# pass is pipelined over micro-batches.
self.constant = 512
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
def __call__(self, tokens, position_ids, attention_mask):
if tokens.size(0) * tokens.size(1) >= self.constant:
micro_batch_size = max(1, self.constant // tokens.size(1))
return _with_pipelining_forward_step(self.model, tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
else:
return _no_pipelining_forward_step(self.model, tokens,
position_ids,
attention_mask,
self.inference_params)
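For orientation, here is a minimal standalone sketch of the dispatch above. The name plan_forward and the tuple it returns are illustrative, not part of the commit; only the threshold test and the micro-batch-size arithmetic mirror ForwardStep.__call__.

# Hypothetical helper mirroring the dispatch in ForwardStep.__call__;
# the commit hard-codes the threshold as self.constant = 512.
PIPELINING_THRESHOLD = 512

def plan_forward(batch_size, sequence_length):
    """Return ('pipelined', micro_batch_size) or ('simple', batch_size)."""
    if batch_size * sequence_length >= PIPELINING_THRESHOLD:
        # Keep micro_batch_size * sequence_length near the threshold,
        # but never smaller than a single sample.
        micro_batch_size = max(1, PIPELINING_THRESHOLD // sequence_length)
        return 'pipelined', micro_batch_size
    return 'simple', batch_size

print(plan_forward(8, 256))  # ('pipelined', 2): 2048 >= 512
print(plan_forward(2, 64))   # ('simple', 2): 128 < 512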
class SimplePipeliningForwardStep(ForwardStepBase):
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
super().__init__(model)
self.batch_size = batch_size
# Divide the batch dimension into micro batches.
self.num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
self.micro_batch_size_list = []
self.batch_dim_start_index = [0]
for i in range(self.num_micro_batches):
self.micro_batch_size_list.append(micro_batch_size)
self.batch_dim_start_index.append((i + 1) * micro_batch_size)
if last_chunk > 0:
self.num_micro_batches += 1
self.micro_batch_size_list.append(last_chunk)
self.batch_dim_start_index.append(batch_size)
self.inference_params = InferenceParams(self.micro_batch_size_list,
max_sequence_len)
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def __call__(self, tokens, position_ids, attention_mask):
# Need to tell p2p_communicate functions the correct size.
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.size(1)
assert args.seq_length <= self.inference_params.max_sequence_len
def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
logits = torch.empty(tokens.size(0),
tokens.size(1),
args.padded_vocab_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# Receive from previous stage.
if not mpu.is_pipeline_first_stage():
torch.distributed.recv(recv_buffer,
src=mpu.get_pipeline_model_parallel_prev_rank())
# Pipeline using micro batches.
for micro_batch_index in range(self.num_micro_batches):
# Set micro-batch size and index.
self.inference_params.micro_batch_index = micro_batch_index
args.micro_batch_size = self.micro_batch_size_list[
micro_batch_index]
# Slice along the batch dimension.
start = self.batch_dim_start_index[micro_batch_index]
end = self.batch_dim_start_index[micro_batch_index + 1]
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
# Receive from previous stage.
input_tensor = recv_forward()
# Send output to the next stage.
if not mpu.is_pipeline_last_stage():
torch.distributed.send(output_tensor,
mpu.get_pipeline_model_parallel_next_rank())
# Forward pass through the model.
self.model.set_input_tensor(input_tensor)
output_tensor = self.model(tokens2use, position_ids2use,
attention_mask,
inference_params=self.inference_params)
# Make sure we do not allocate context memory anymore.
if inference_params.allocate_key_value_memory:
inference_params.allocate_key_value_memory = False
# Send output to the next stage.
send_forward(output_tensor)
# Reset the sequence length to whatever it was before.
# Make sure we do not allocate context memory anymore.
if self.inference_params.allocate_key_value_memory:
self.inference_params.allocate_key_value_memory = False
return output_tensor
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output_tensor
# Adjust the sequence length back to whatever it was before.
args.seq_length = orig_seq_length
return logits
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Update the sequence length offset.
inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
class NoPipeliningForwardStep(ForwardStepBase):
return logits
def __init__(self, model, batch_size, max_sequence_len):
super().__init__(model)
self.inference_params = InferenceParams([batch_size], max_sequence_len)
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
def __call__(self, tokens, position_ids, attention_mask):
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Need to tell p2p_communicate functions the correct size.
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
assert args.seq_length <= self.inference_params.max_sequence_len
args.micro_batch_size = tokens.shape[0]
assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
assert self.inference_params.micro_batch_index == 0
# Receive from previous stage.
input_tensor = recv_forward()
# Forward pass through the model.
self.model.set_input_tensor(input_tensor)
output_tensor = self.model(tokens, position_ids, attention_mask,
inference_params=self.inference_params)
# Send output to the next stage.
send_forward(output_tensor)
# Reset the sequence length to whatever it was before.
args.seq_length = orig_seq_length
# Make sure we do not allocate context memory anymore.
if self.inference_params.allocate_key_value_memory:
self.inference_params.allocate_key_value_memory = False
return output_tensor
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
# Slice along the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = _forward_step_helper(model, tokens2use, position_ids2use,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
inference_params.batch_size_offset = 0
return logits
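The loop in _with_pipelining_forward_step above carves the batch dimension into fixed-size micro-batches plus one possibly smaller remainder chunk. A small self-contained sketch of just that slicing arithmetic (split_micro_batches is an illustrative name; no model or communication involved):

def split_micro_batches(batch_size, micro_batch_size):
    """Yield (start, end) ranges covering the batch dimension, with a
    smaller final chunk when batch_size is not evenly divisible."""
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    for i in range(num_micro_batches):
        start = i * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        yield start, end

# Example: a batch of 10 with micro-batches of 4 -> (0, 4), (4, 8), (8, 10).
print(list(split_micro_batches(10, 4)))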
@@ -24,7 +24,7 @@ from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step_provider
from .forward_step import ForwardStep
from .sampling import sample
@@ -66,8 +66,7 @@ def generate_tokens_probs_and_return_on_first_stage(
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# forward step.
forward_step = forward_step_provider(model, batch_size, 4,
max_sequence_length)
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
@@ -190,8 +189,8 @@ def generate_tokens_probs_and_return_on_first_stage(
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if done:
break
#if done:
# break
# ===================================================
# Update the length based on the max generated length.
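Per the hunk above, the driver now builds a single ForwardStep(model, batch_size, max_sequence_length) object, replacing forward_step_provider, and calls it once per decode step while the object owns the InferenceParams. A rough usage sketch under the simplifying assumption of a single pipeline stage; greedy_decode and its greedy sampling are illustrative and stand in for the repo's sampling, probability bookkeeping, and cross-stage token copying:

import torch

def greedy_decode(forward_step, tokens, position_ids, attention_mask, steps):
    # Illustrative driver, assuming pipeline_model_parallel_size == 1 so the
    # forward step always returns logits on this rank. Because keys/values
    # are cached through InferenceParams, only unprocessed tokens are fed.
    generated = tokens
    inputs = tokens
    for _ in range(steps):
        logits = forward_step(inputs, position_ids, attention_mask)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        inputs = next_token
        # position_ids and attention_mask updates are omitted for brevity.
    return generated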
@@ -180,9 +180,8 @@ class ParallelAttention(MegatronModule):
skip_bias_add=True)
# Inference key-value memory
self.inference_key_memory_list = None
self.inference_value_memory_list = None
self.inference_current_sequence_len_list = None
self.inference_key_memory = None
self.inference_value_memory = None
def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -206,22 +205,17 @@ class ParallelAttention(MegatronModule):
if inference_params:
if inference_params.allocate_key_value_memory:
inf_max_seq_len = inference_params.max_sequence_len
inf_batch_sizes = inference_params.micro_batch_size_list
self.inference_key_memory_list = [
self._allocate_memory(inf_max_seq_len, inf_batch_size)
for inf_batch_size in inf_batch_sizes]
self.inference_value_memory_list = [
self._allocate_memory(inf_max_seq_len, inf_batch_size)
for inf_batch_size in inf_batch_sizes]
self.inference_current_sequence_len_list = [
0 for _ in inf_batch_sizes]
inf_max_batch_size = inference_params.max_batch_size
self.inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
self.inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
# This is added for safety. In case inference_params
# is not provided, make sure there is no potential memory left
# from previous inference.
else:
self.inference_key_memory_list = None
self.inference_value_memory_list = None
self.inference_current_sequence_len_list = None
self.inference_value_memory = None
self.inference_current_sequence_len = None
# =====================
# Query, Key, and Value
@@ -269,23 +263,23 @@
# ==================================
if inference_params:
inf_batch_index = inference_params.micro_batch_index
assert key_layer.size(1) == \
inference_params.micro_batch_size_list[inf_batch_index]
# Adjust the range variables.
start = self.inference_current_sequence_len_list[inf_batch_index]
end = start + key_layer.size(0)
assert end <= inference_params.max_sequence_len
self.inference_current_sequence_len_list[inf_batch_index] = end
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= self.inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= self.inference_key_memory.size(0)
# Copy keys and values.
self.inference_key_memory_list[inf_batch_index][start:end, ...] \
= key_layer
self.inference_value_memory_list[inf_batch_index][start:end, ...] \
= value_layer
key_layer = \
self.inference_key_memory_list[inf_batch_index][:end, ...]
value_layer = \
self.inference_value_memory_list[inf_batch_index][:end, ...]
self.inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end,
...] = key_layer
self.inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end,
...] = value_layer
key_layer = self.inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = self.inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# ===================================
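The ParallelAttention change above replaces the per-micro-batch key/value memory lists with one preallocated buffer indexed by the offsets carried in InferenceParams: sequence_len_offset picks where new tokens land along the time axis, and batch_size_offset selects the current micro-batch's slice along the batch axis. A toy, self-contained sketch of that indexing (dimensions reduced to [sequence, batch, hidden]; cache_keys is an illustrative name and no attention math is shown):

import torch

# Toy sizes; the real buffers also carry per-attention-head dimensions.
max_seq, max_batch, hidden = 16, 4, 8
key_memory = torch.zeros(max_seq, max_batch, hidden)

def cache_keys(key_layer, sequence_len_offset, batch_size_offset):
    """Write new keys into the preallocated cache and return the
    [0:sequence_end] view of this micro-batch used by attention."""
    sequence_start = sequence_len_offset
    sequence_end = sequence_start + key_layer.size(0)
    batch_start = batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    key_memory[sequence_start:sequence_end, batch_start:batch_end, :] = key_layer
    return key_memory[:sequence_end, batch_start:batch_end, :]

# A 5-token prompt for a micro-batch of 2 samples, then one decode step.
print(cache_keys(torch.randn(5, 2, hidden), 0, 0).shape)  # torch.Size([5, 2, 8])
print(cache_keys(torch.randn(1, 2, hidden), 5, 0).shape)  # torch.Size([6, 2, 8])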