Commit 8c119d80 authored by mshoeybi

tested and working

parent a7539b0f
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
# Custom arguments.
if extra_args_provider is not None:
@@ -273,6 +274,18 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
return parser
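For intuition, a rough sketch of how this threshold is meant to gate pipelined inference; the numbers below are made up for illustration, and the real gating logic lives in ForwardStep later in this commit:

# Illustrative sketch only; the real threshold comes from
# args.inference_batch_times_seqlen_threshold (default 512).
threshold = 512
batch_size, seq_len = 8, 128                  # tokens.size(0), tokens.size(1)
if batch_size * seq_len >= threshold:         # 1024 >= 512, so use pipelining
    micro_batch_size = max(1, threshold // seq_len)   # 512 // 128 = 4
else:
    micro_batch_size = batch_size             # single forward pass, no pipelining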
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
@@ -26,14 +26,20 @@ from .tokenization (
detokenize_generations)
def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
return_all_log_probs=False,
greedy_sampling=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False):
"""TO DO ..."""
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Run inferecne and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs, all_log_probs = generate(
@@ -42,8 +48,12 @@ def generate_and_post_process(model,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
return_all_log_probs=return_all_log_probs,
greedy_sampling=greedy_sampling,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
temperature=temperature,
add_BOS=add_BOS)
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
@@ -62,24 +72,42 @@ def generate_and_post_process(model,
return None
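As a usage illustration, the entry point above could be driven as sketched below; the prompt text, parameter values, and the assumption that a ready model object is in scope are all hypothetical, and the exact structure of the returned object is not shown in this diff:

# Hypothetical usage sketch: only the first pipeline stage returns the
# post-processed (detokenized, CPU, list-converted) results; other ranks get None.
result = generate_and_post_process(model,
                                   prompts=['Hello, my name is'],
                                   tokens_to_generate=32,
                                   top_k_sampling=10,
                                   temperature=0.9)
if result is not None:
    print(result)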
def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
return_all_log_probs=False,
greedy_sampling=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False):
"""TO DO ..."""
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
all_log_probs: full log probs for all of the tokens.
"""
# Make sure input params are available to all ranks.
values = [tokens_to_generate, return_output_log_probs,
return_all_log_probs, temperature, add_BOS]
values_float_tensor = broadcast_float_list(5, float_list=values)
values = [tokens_to_generate,
return_output_log_probs, return_all_log_probs,
greedy_sampling, top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(9, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
return_all_log_probs = bool(values_float_tensor[2].item())
temperature = values_float_tensor[3].item()
add_BOS = bool(values_float_tensor[4].item())
greedy_sampling = bool(values_float_tensor[3].item())
top_k_sampling = int(values_float_tensor[4].item())
top_p_sampling = values_float_tensor[5].item()
temperature = values_float_tensor[6].item()
add_BOS = bool(values_float_tensor[7].item())
use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
# Tokenize prompts and get the batch.
# Note that these tensors are broadcast to all ranks.
@@ -95,4 +123,6 @@ def generate(model,
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
return_all_log_probs=return_all_log_probs,
temperature=temperature)
greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
@@ -21,6 +21,38 @@ import torch
from megatron import mpu
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer in place."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
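Taken together, a middle pipeline stage pairs these two helpers around its forward pass, roughly as in the simplified sketch below; the buffer allocation and the surrounding variables are assumed to be in scope, and the real wiring (including _allocate_recv_buffer) appears in _forward_step_helper later in this commit:

# Simplified sketch for a stage that is neither first nor last in the pipeline.
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
recv_from_prev_pipeline_rank_(recv_buffer)       # pull activations from the previous stage
model.set_input_tensor(recv_buffer)              # feed them into this stage's layers
output_tensor = model(tokens, position_ids, attention_mask,
                      inference_params=inference_params)
send_to_next_pipeline_rank(output_tensor)        # hand activations to the next stage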
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
@@ -96,6 +128,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
@@ -114,6 +147,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
@@ -125,12 +159,14 @@ def broadcast_list(size, dtype, list_values=None, rank=0):
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
@@ -22,14 +22,20 @@ import torch
from megatron import (
get_args,
mpu)
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_len):
"""Note that offsets are set to zero and we always set the
flag to allocate memory. After the first call, make sure to
set this flag to False."""
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
@@ -39,38 +45,50 @@ class InferenceParams:
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_len):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
if isinstance(model, Iterable):
for this_model in model:
this_model.eval()
else:
model.eval()
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
self.constant = 512
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def __call__(self, tokens, position_ids, attention_mask):
if tokens.size(0) * tokens.size(1) >= self.constant:
micro_batch_size = max(1, self.constant // tokens.size(1))
return _with_pipelining_forward_step(self.model, tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
else:
return _no_pipelining_forward_step(self.model, tokens,
position_ids,
attention_mask,
self.inference_params)
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params)
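A brief usage sketch of the class above (batch size and sequence length are illustrative); whether a call runs as one pass or is split into micro-batches depends on the threshold discussed earlier:

# Hypothetical usage of ForwardStep; InferenceParams is hidden from the caller.
forward_step = ForwardStep(model, max_batch_size=8, max_sequence_len=1024)
logits = forward_step(tokens, position_ids, attention_mask)
# Pipelined micro-batching kicks in only when tokens.size(0) * tokens.size(1)
# reaches args.inference_batch_times_seqlen_threshold.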
def _get_recv_buffer_dtype(args):
@@ -103,9 +121,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
if not mpu.is_pipeline_first_stage():
torch.distributed.recv(recv_buffer,
src=mpu.get_pipeline_model_parallel_prev_rank())
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
@@ -113,9 +129,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params=inference_params)
# Send output to the next stage.
if not mpu.is_pipeline_last_stage():
torch.distributed.send(output_tensor,
mpu.get_pipeline_model_parallel_next_rank())
send_to_next_pipeline_rank(output_tensor)
# Make sure we do not allocate context memory anymore.
if inference_params.allocate_key_value_memory:
@@ -128,7 +142,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
@@ -143,9 +157,10 @@ def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
return logits
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
@@ -32,10 +32,12 @@ def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
return_all_log_probs=False,
temperature=1.0):
greedy=False, top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
"""Main token generation function.
Arguments:
model: XXX
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
@@ -44,7 +46,14 @@ def generate_tokens_probs_and_return_on_first_stage(
return_all_log_probs: flag to calculate the log probability of across
all the tokens (vocab size). Note that the log probability is the
one after logits are modified for sampling.
greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
Note that these three parameters are mutually exclusive, meaning that:
if greedy = true then we should have top-k = top-p = 0;
if top-k > 0 then we expect greedy = false and top-p = 0;
if top-p > 0 then we check for greedy = false and top-k = 0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, terminate generation early
once all the sequences have reached the EOD token.
Note: Apart from the model, the other parameters only need to be available
on rank 0.
Outputs: Note that its size is adjusted to a lower value than
@@ -108,10 +117,9 @@ def generate_tokens_probs_and_return_on_first_stage(
# Run inference
# =============
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
@@ -132,9 +140,9 @@ def generate_tokens_probs_and_return_on_first_stage(
last_token_logits = logits[:, -1, :]
new_sample, updated_last_token_logits = sample(
last_token_logits,
greedy=args.greedy,
top_k=args.top_k,
top_p=args.top_p,
greedy=greedy,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# Now that we have the sample and updated logits,
@@ -189,8 +197,8 @@ def generate_tokens_probs_and_return_on_first_stage(
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
#if done:
# break
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the lengths based on the max generated length.