Commit a7539b0f authored by mshoeybi

pipelining works

parent 8f160844
@@ -15,8 +15,6 @@
 """Forward step utilities."""

-from abc import ABC
-from abc import abstractmethod
 from collections.abc import Iterable

 import torch
@@ -24,44 +22,27 @@ import torch
 from megatron import (
     get_args,
     mpu)
-from megatron.p2p_communication import (
-    recv_forward,
-    send_forward)
-
-
-def forward_step_provider(model,
-                          batch_size,
-                          micro_batch_size,
-                          max_sequence_len):
-    args = get_args()
-    if args.pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
-        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
-    return SimplePipeliningForwardStep(model, batch_size,
-                                       micro_batch_size,
-                                       max_sequence_len)


 class InferenceParams:

-    def __init__(self, micro_batch_size_list, max_sequence_len):
-        assert isinstance(micro_batch_size_list, list)
-        assert max_sequence_len > 0
-        self.micro_batch_size_list = micro_batch_size_list
+    def __init__(self, max_batch_size, max_sequence_len):
         self.max_sequence_len = max_sequence_len
+        self.max_batch_size = max_batch_size
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
         self.allocate_key_value_memory = True
-        self.micro_batch_index = 0


-class ForwardStepBase(ABC):
+class ForwardStep:

-    def __init__(self, model):
+    def __init__(self, model, max_batch_size, max_sequence_len):
         # Make sure model is in eval mode.
         if isinstance(model, Iterable):
             for this_model in model:
                 this_model.eval()
@@ -69,125 +50,148 @@ class ForwardStepBase(ABC):
             model.eval()
         self.model = model
-
-    @abstractmethod
-    def __call__(self, tokens, position_ids, attention_mask):
-        pass
-
-
-class SimplePipeliningForwardStep(ForwardStepBase):
-
-    def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
-        super().__init__(model)
-        self.batch_size = batch_size
-        # Divide the batch dimension into micro batches.
-        self.num_micro_batches, last_chunk = divmod(batch_size,
-                                                    micro_batch_size)
-        self.micro_batch_size_list = []
-        self.batch_dim_start_index = [0]
-        for i in range(self.num_micro_batches):
-            self.micro_batch_size_list.append(micro_batch_size)
-            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
-        if last_chunk > 0:
-            self.num_micro_batches += 1
-            self.micro_batch_size_list.append(last_chunk)
-            self.batch_dim_start_index.append(batch_size)
-
-        self.inference_params = InferenceParams(self.micro_batch_size_list,
-                                                max_sequence_len)
-
-    def __call__(self, tokens, position_ids, attention_mask):
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.size(1)
-        assert args.seq_length <= self.inference_params.max_sequence_len
-
-        # Preallocate memory for output logits.
-        logits = None
-        if mpu.is_pipeline_last_stage():
-            logits = torch.empty(tokens.size(0),
-                                 tokens.size(1),
-                                 args.padded_vocab_size,
-                                 dtype=torch.float32,
-                                 device=torch.cuda.current_device())
-
-        # Pipeline using micro batches.
-        for micro_batch_index in range(self.num_micro_batches):
-            # Set micro-batch size and index.
-            self.inference_params.micro_batch_index = micro_batch_index
-            args.micro_batch_size = self.micro_batch_size_list[
-                micro_batch_index]
-            # Slice along the batch dimension.
-            start = self.batch_dim_start_index[micro_batch_index]
-            end = self.batch_dim_start_index[micro_batch_index + 1]
-            tokens2use = tokens[start:end, ...]
-            position_ids2use = position_ids[start:end, ...]
-
-            # Receive from previous stage.
-            input_tensor = recv_forward()
-
-            # Forward pass through the model.
-            self.model.set_input_tensor(input_tensor)
-            output_tensor = self.model(tokens2use, position_ids2use,
-                                       attention_mask,
-                                       inference_params=self.inference_params)
-
-            # Send output to the next stage.
-            send_forward(output_tensor)
-
-            # Reset the sequence length to whatever it was before.
-            # Make sure we do not allocate context memory anymore.
-            if self.inference_params.allocate_key_value_memory:
-                self.inference_params.allocate_key_value_memory = False
-
-            if mpu.is_pipeline_last_stage():
-                logits[start:end, ...] = output_tensor
-
-        # Adjust the sequence length back to whatever it was before.
-        args.seq_length = orig_seq_length
-
-        return logits
-
-
-class NoPipeliningForwardStep(ForwardStepBase):
-
-    def __init__(self, model, batch_size, max_sequence_len):
-        super().__init__(model)
-        self.inference_params = InferenceParams([batch_size], max_sequence_len)
-
-    def __call__(self, tokens, position_ids, attention_mask):
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.shape[1]
-        assert args.seq_length <= self.inference_params.max_sequence_len
-        args.micro_batch_size = tokens.shape[0]
-        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
-        assert self.inference_params.micro_batch_index == 0
-
-        # Receive from previous stage.
-        input_tensor = recv_forward()
-
-        # Forward pass through the model.
-        self.model.set_input_tensor(input_tensor)
-        output_tensor = self.model(tokens, position_ids, attention_mask,
-                                   inference_params=self.inference_params)
-
-        # Send output to the next stage.
-        send_forward(output_tensor)
-
-        # Reset the sequence length to whatever it was before.
-        args.seq_length = orig_seq_length
-        # Make sure we do not allocate context memory anymore.
-        if self.inference_params.allocate_key_value_memory:
-            self.inference_params.allocate_key_value_memory = False
-
-        return output_tensor
+        self.constant = 512
+
+        # Initialize inference parameters.
+        self.inference_params = InferenceParams(max_batch_size,
+                                                max_sequence_len)
+
+    def __call__(self, tokens, position_ids, attention_mask):
+        if tokens.size(0) * tokens.size(1) >= self.constant:
+            micro_batch_size = max(1, self.constant // tokens.size(1))
+            return _with_pipelining_forward_step(self.model, tokens,
+                                                 position_ids,
+                                                 attention_mask,
+                                                 self.inference_params,
+                                                 micro_batch_size)
+        else:
+            return _no_pipelining_forward_step(self.model, tokens,
+                                               position_ids,
+                                               attention_mask,
+                                               self.inference_params)
+
+
+def _get_recv_buffer_dtype(args):
+    """Receive happens between the layers."""
+    if args.fp32_residual_connection:
+        return torch.float
+    return args.params_dtype
+
+
+def _allocate_recv_buffer(batch_size, sequence_length):
+    """Receive happens between the layers with size [s, b, h]."""
+    if mpu.is_pipeline_first_stage():
+        return None
+    args = get_args()
+    recv_size = (sequence_length, batch_size, args.hidden_size)
+    return torch.empty(recv_size,
+                       dtype=_get_recv_buffer_dtype(args),
+                       device=torch.cuda.current_device())
+
+
+def _forward_step_helper(model, tokens, position_ids, attention_mask,
+                         inference_params, recv_buffer=None):
+    """Single forward step. Update the allocate-memory flag so the
+    memory is only allocated on the first call."""
+    batch_size = tokens.size(0)
+    sequence_length = tokens.size(1)
+    if recv_buffer is None:
+        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
+
+    # Receive from previous stage.
+    if not mpu.is_pipeline_first_stage():
+        torch.distributed.recv(recv_buffer,
+                               src=mpu.get_pipeline_model_parallel_prev_rank())
+
+    # Forward pass through the model.
+    model.set_input_tensor(recv_buffer)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)
+
+    # Send output to the next stage.
+    if not mpu.is_pipeline_last_stage():
+        torch.distributed.send(output_tensor,
+                               mpu.get_pipeline_model_parallel_next_rank())
+
+    # Make sure we do not allocate context memory anymore.
+    if inference_params.allocate_key_value_memory:
+        inference_params.allocate_key_value_memory = False
+
+    return output_tensor
+
+
+def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                inference_params, recv_buffer=None):
+    # Run a simple forward pass.
+    output_tensor = _forward_step_helper(model, tokens, position_ids,
+                                         attention_mask, inference_params,
+                                         recv_buffer=recv_buffer)
+    # Update the sequence length offset.
+    inference_params.sequence_len_offset += tokens.size(1)
+
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        logits = output_tensor
+
+    return logits
+
+
+def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                  inference_params, micro_batch_size):
+    sequence_length = tokens.size(1)
+    batch_size = tokens.size(0)
+
+    # Divide the batch dimension into micro batches.
+    num_micro_batches, last_chunk = divmod(batch_size,
+                                           micro_batch_size)
+    if last_chunk > 0:
+        num_micro_batches += 1
+
+    # Preallocate memory for output logits.
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        args = get_args()
+        logits = torch.empty(
+            (batch_size, sequence_length, args.padded_vocab_size),
+            dtype=torch.float32, device=torch.cuda.current_device())
+
+    # Preallocate recv buffer.
+    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
+
+    for micro_batch_index in range(num_micro_batches):
+        # Slice along the batch dimension.
+        start = micro_batch_index * micro_batch_size
+        end = min(start + micro_batch_size, batch_size)
+        this_micro_batch_size = end - start
+        tokens2use = tokens[start:end, ...]
+        position_ids2use = position_ids[start:end, ...]
+
+        # Run a simple forward pass.
+        if this_micro_batch_size != micro_batch_size:
+            recv_buffer = None
+        output = _forward_step_helper(model, tokens2use, position_ids2use,
+                                      attention_mask, inference_params,
+                                      recv_buffer=recv_buffer)
+
+        # Adjust the batch size offset to account for the micro-batch.
+        inference_params.batch_size_offset += this_micro_batch_size
+
+        # Copy logits.
+        if mpu.is_pipeline_last_stage():
+            logits[start:end, ...] = output
+
+    # Once we are done with all the micro-batches, we can
+    # adjust the sequence length offset ...
+    inference_params.sequence_len_offset += sequence_length
+    # ... and reset the batch size offset.
+    inference_params.batch_size_offset = 0
+
+    return logits
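
For reference, a minimal standalone sketch (not part of this commit) of the dispatch rule in the new ForwardStep.__call__: pipelining kicks in once a request carries at least self.constant = 512 tokens (batch size times sequence length), and each micro batch is sized as max(1, 512 // sequence_length). The helper name plan_forward and the example shapes are illustrative only.

# Illustrative sketch: mirrors the threshold logic in ForwardStep.__call__,
# assuming the same constant of 512; plan_forward is a made-up helper name.
def plan_forward(batch_size, sequence_length, constant=512):
    """Return (use_pipelining, micro_batch_size) the way ForwardStep decides."""
    if batch_size * sequence_length >= constant:
        # Pipelined path: split the batch so each micro batch holds roughly
        # `constant` tokens, but never fewer than one sample.
        return True, max(1, constant // sequence_length)
    # Small inputs run as a single un-pipelined forward pass.
    return False, batch_size

# Example: a 16 x 128 prompt batch is split into micro batches of 4 samples,
# while a 16 x 1 step stays un-pipelined.
print(plan_forward(16, 128))   # (True, 4)
print(plan_forward(16, 1))     # (False, 16)
print(plan_forward(600, 1))    # (True, 512)
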
@@ -24,7 +24,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step_provider
+from .forward_step import ForwardStep
 from .sampling import sample
@@ -66,8 +66,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

     # forward step.
-    forward_step = forward_step_provider(model, batch_size, 4,
-                                         max_sequence_length)
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)

     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.
@@ -190,8 +189,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         done = torch.all(is_generation_done)
         done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                   tensor=done)
-        if done:
-            break
+        #if done:
+        #    break

     # ===================================================
     # Update the length of based on max generated length.
...
@@ -180,9 +180,8 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)

         # Inference key-value memory
-        self.inference_key_memory_list = None
-        self.inference_value_memory_list = None
-        self.inference_current_sequence_len_list = None
+        self.inference_key_memory = None
+        self.inference_value_memory = None

     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -206,22 +205,17 @@ class ParallelAttention(MegatronModule):
         if inference_params:
             if inference_params.allocate_key_value_memory:
                 inf_max_seq_len = inference_params.max_sequence_len
-                inf_batch_sizes = inference_params.micro_batch_size_list
-                self.inference_key_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_value_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_current_sequence_len_list = [
-                    0 for _ in inf_batch_sizes]
+                inf_max_batch_size = inference_params.max_batch_size
+                self.inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                self.inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
         # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
         else:
-            self.inference_key_memory_list = None
-            self.inference_value_memory_list = None
-            self.inference_current_sequence_len_list = None
+            self.inference_value_memory = None
+            self.inference_current_sequence_len = None

         # =====================
         # Query, Key, and Value
@@ -269,23 +263,23 @@ class ParallelAttention(MegatronModule):
         # ==================================
         if inference_params:
-            inf_batch_index = inference_params.micro_batch_index
-            assert key_layer.size(1) == \
-                inference_params.micro_batch_size_list[inf_batch_index]
-            # Adjust the range variables.
-            start = self.inference_current_sequence_len_list[inf_batch_index]
-            end = start + key_layer.size(0)
-            assert end <= inference_params.max_sequence_len
-            self.inference_current_sequence_len_list[inf_batch_index] = end
-            # Copy key and values.
-            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
-                = key_layer
-            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
-                = value_layer
-            key_layer = \
-                self.inference_key_memory_list[inf_batch_index][:end, ...]
-            value_layer = \
-                self.inference_value_memory_list[inf_batch_index][:end, ...]
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= self.inference_key_memory.size(1)
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= self.inference_key_memory.size(0)
+            # Copy key and values.
+            self.inference_key_memory[sequence_start:sequence_end,
+                                      batch_start:batch_end,
+                                      ...] = key_layer
+            self.inference_value_memory[sequence_start:sequence_end,
+                                        batch_start:batch_end,
+                                        ...] = value_layer
+            key_layer = self.inference_key_memory[
+                :sequence_end, batch_start:batch_end, ...]
+            value_layer = self.inference_value_memory[
+                :sequence_end, batch_start:batch_end, ...]

         # ===================================
...
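
For reference, a minimal standalone sketch (not part of this commit) of the new key/value cache indexing: ParallelAttention keeps a single preallocated buffer of shape [max_sequence_len, max_batch_size, ...] and slices it with the sequence_len_offset and batch_size_offset carried by InferenceParams, so each micro batch writes into its own batch columns and each later step appends along the sequence dimension. The sizes and the helper name write_kv below are made up for illustration; a plain tensor stands in for the attention layer's memory.

# Illustrative sketch of the offset-based KV-cache slicing, with toy sizes.
import torch

max_seq, max_batch, hidden = 8, 4, 2
key_memory = torch.zeros(max_seq, max_batch, hidden)

def write_kv(key_layer, sequence_len_offset, batch_size_offset):
    # Same slicing as in ParallelAttention: dim 0 is sequence, dim 1 is batch.
    seq_start = sequence_len_offset
    seq_end = seq_start + key_layer.size(0)
    batch_start = batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    key_memory[seq_start:seq_end, batch_start:batch_end, ...] = key_layer
    # Attention then reads everything cached so far for this micro batch.
    return key_memory[:seq_end, batch_start:batch_end, ...]

# Prompt of length 3, batch of 4 split into two micro batches of 2 samples:
write_kv(torch.ones(3, 2, hidden), sequence_len_offset=0, batch_size_offset=0)
write_kv(torch.ones(3, 2, hidden), sequence_len_offset=0, batch_size_offset=2)
# Next step: one new token per sample, sequence_len_offset advanced to 3.
out = write_kv(torch.ones(1, 2, hidden), sequence_len_offset=3, batch_size_offset=0)
print(out.shape)  # torch.Size([4, 2, 2])
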