"git@developer.sourcefind.cn:change/sglang.git" did not exist on "82e6c3a65ab3701c3ef498bc51fbe447e8c6cbe5"
Commit 8c119d80 authored by mshoeybi

tested and working

parent a7539b0f
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_inference_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -273,6 +274,18 @@ def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


+def _add_inference_args(parser):
+    group = parser.add_argument_group(title='inference')
+
+    group.add_argument('--inference-batch-times-seqlen-threshold',
+                       type=int, default=512,
+                       help='During inference, if batch-size times '
+                       'sequence-length is smaller than this threshold '
+                       'then we will not use pipelining, otherwise we will.')
+    return parser
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
...
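For reference, the decision this new threshold drives (made in ForwardStep.__call__ further down in this commit) can be summarized in a small sketch; the helper names below are illustrative, not part of the commit:

def use_pipelined_forward(batch_size, seq_len, threshold=512):
    # Mirrors the check in ForwardStep.__call__: once batch-size times
    # sequence-length reaches the threshold, the batch is split into
    # micro-batches and run through the pipelined forward step.
    return batch_size * seq_len >= threshold

def micro_batch_size_for(seq_len, threshold=512):
    # Matching micro-batch size used on the pipelined path.
    return max(1, threshold // seq_len)

# Example: 8 prompts of length 128 give 8 * 128 = 1024 >= 512, so the batch
# is pipelined with micro-batches of max(1, 512 // 128) = 4 prompts.
print(use_pipelined_forward(8, 128), micro_batch_size_for(128))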
@@ -26,14 +26,20 @@ from .tokenization import (
     detokenize_generations)


 def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
                               return_all_log_probs=False,
+                              greedy_sampling=False,
+                              top_k_sampling=0,
+                              top_p_sampling=0.0,
                               temperature=1.0,
-                              add_BOS=False):
-    """TO DO ..."""
+                              add_BOS=False,
+                              use_eod_token_for_early_termination=True):
+    """Run inference and post-process outputs, i.e., detokenize,
+    move to cpu and convert to list."""

     # Main inference.
     tokens, lengths, output_log_probs, all_log_probs = generate(
@@ -42,8 +48,12 @@ def generate_and_post_process(model,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
         return_all_log_probs=return_all_log_probs,
+        greedy_sampling=greedy_sampling,
+        top_k_sampling=top_k_sampling,
+        top_p_sampling=top_p_sampling,
         temperature=temperature,
-        add_BOS=add_BOS)
+        add_BOS=add_BOS,
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)

     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
@@ -62,24 +72,42 @@ def generate_and_post_process(model,
     return None


 def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
              return_all_log_probs=False,
+             greedy_sampling=False,
+             top_k_sampling=0,
+             top_p_sampling=0.0,
              temperature=1.0,
-             add_BOS=False):
-    """TO DO ..."""
+             add_BOS=False,
+             use_eod_token_for_early_termination=True):
+    """Given prompts and input parameters, run inference and return:
+       tokens: prompts plus the generated tokens.
+       lengths: length of the prompt + generations. Note that we can
+           discard tokens in the tokens tensor that are after the
+           corresponding length.
+       output_log_probs: log probs of the tokens.
+       all_log_probs: full log probs for all of the tokens.
+    """

     # Make sure input params are available to all ranks.
-    values = [tokens_to_generate, return_output_log_probs,
-              return_all_log_probs, temperature, add_BOS]
-    values_float_tensor = broadcast_float_list(5, float_list=values)
+    values = [tokens_to_generate,
+              return_output_log_probs, return_all_log_probs,
+              greedy_sampling, top_k_sampling, top_p_sampling,
+              temperature, add_BOS, use_eod_token_for_early_termination]
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     return_all_log_probs = bool(values_float_tensor[2].item())
-    temperature = values_float_tensor[3].item()
-    add_BOS = bool(values_float_tensor[4].item())
+    greedy_sampling = bool(values_float_tensor[3].item())
+    top_k_sampling = int(values_float_tensor[4].item())
+    top_p_sampling = values_float_tensor[5].item()
+    temperature = values_float_tensor[6].item()
+    add_BOS = bool(values_float_tensor[7].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcast to all ranks.
@@ -95,4 +123,6 @@ def generate(model,
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
         return_all_log_probs=return_all_log_probs,
-        temperature=temperature)
+        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
+        temperature=temperature,
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
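A hypothetical call showing the new sampling arguments end to end. The import path and the already-built `model` are assumptions, not part of this commit; following the exclusivity rule documented later in this commit, only top-k sampling is enabled here:

from megatron.text_generation.api import generate_and_post_process  # assumed path

result = generate_and_post_process(model,
                                   prompts=["Megatron-LM is"],
                                   tokens_to_generate=32,
                                   return_output_log_probs=True,
                                   greedy_sampling=False,
                                   top_k_sampling=40,
                                   top_p_sampling=0.0,
                                   temperature=0.8,
                                   add_BOS=False,
                                   use_eod_token_for_early_termination=True)
# Only the first pipeline stage post-processes and returns the detokenized
# generations (and log probs if requested); every other rank gets None.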
@@ -21,6 +21,38 @@ import torch
 from megatron import mpu


+def recv_from_prev_pipeline_rank_(recv_buffer=None):
+    """Receive from previous pipeline stage and update the
+    input buffer inplace."""
+    if not mpu.is_pipeline_first_stage():
+        assert recv_buffer is not None
+        recv_prev_op = torch.distributed.P2POp(
+            torch.distributed.irecv, recv_buffer,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
+        for req in reqs:
+            req.wait()
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()
+
+
+def send_to_next_pipeline_rank(tensor=None):
+    """Send output to the next pipeline stage."""
+    if not mpu.is_pipeline_last_stage():
+        assert tensor is not None
+        send_next_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor,
+            mpu.get_pipeline_model_parallel_next_rank())
+        reqs = torch.distributed.batch_isend_irecv([send_next_op])
+        for req in reqs:
+            req.wait()
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()
+
+
 def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""
@@ -96,6 +128,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
             tensor[...] = tensor_


 def broadcast_tensor(size, dtype, tensor=None, rank=0):
     """Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.
@@ -114,6 +147,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
     return tensor


 def broadcast_list(size, dtype, list_values=None, rank=0):
     """Broadcast a list of values with a given type."""
@@ -125,12 +159,14 @@ def broadcast_list(size, dtype, list_values=None, rank=0):
     return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)


 def broadcast_int_list(size, int_list=None, rank=0):
     """Broadcast a list of integer values."""
     return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)


 def broadcast_float_list(size, float_list=None, rank=0):
     """Broadcast a list of float values."""
...
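The value-broadcast pattern these helpers support (used by generate() above): rank 0 packs heterogeneous scalars into one float list and every rank decodes it back. A minimal sketch, assuming the import path and an already-initialized torch.distributed group:

import torch
from megatron.text_generation.communication import broadcast_float_list  # assumed path

values = None
if torch.distributed.get_rank() == 0:
    # [tokens_to_generate, return_output_log_probs, return_all_log_probs,
    #  greedy, top_k, top_p, temperature, add_BOS, use_eod_early_termination]
    values = [64, True, False, False, 0, 0.9, 0.8, False, True]
values_float_tensor = broadcast_float_list(9, float_list=values)

# Every rank recovers the original types from the broadcast float tensor.
tokens_to_generate = int(values_float_tensor[0].item())
top_k_sampling = int(values_float_tensor[4].item())
top_p_sampling = values_float_tensor[5].item()
temperature = values_float_tensor[6].item()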
@@ -22,14 +22,20 @@ import torch
 from megatron import (
     get_args,
     mpu)
+from .communication import (
+    send_to_next_pipeline_rank,
+    recv_from_prev_pipeline_rank_)


 class InferenceParams:
+    """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""

     def __init__(self, max_batch_size, max_sequence_len):
+        """Note that offsets are set to zero and we always set the
+        flag to allocate memory. After the first call, make sure to
+        set this flag to False."""
         self.max_sequence_len = max_sequence_len
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0
@@ -39,38 +45,50 @@ class InferenceParams:


 class ForwardStep:
+    """Forward step function with all the communications.
+    We use a class here to hide the inference parameters
+    from the outside caller."""

     def __init__(self, model, max_batch_size, max_sequence_len):
+        """Set values so we don't need to do it multiple times."""
         # Make sure model is in eval mode.
-        if isinstance(model, Iterable):
-            for this_model in model:
-                this_model.eval()
-        else:
-            model.eval()
+        assert not isinstance(model, Iterable), \
+            'interleaving schedule is not supported for inference'
+        model.eval()
         self.model = model
-        self.constant = 512
         # Initialize inference parameters.
         self.inference_params = InferenceParams(max_batch_size,
                                                 max_sequence_len)
+        # Pipelining arguments.
+        args = get_args()
+        self.pipeline_size_larger_than_one = (
+            args.pipeline_model_parallel_size > 1)
+        # Threshold of pipelining.
+        self.pipelining_batch_x_seqlen = \
+            args.inference_batch_times_seqlen_threshold

     def __call__(self, tokens, position_ids, attention_mask):
-        if tokens.size(0) * tokens.size(1) >= self.constant:
-            micro_batch_size = max(1, self.constant // tokens.size(1))
-            return _with_pipelining_forward_step(self.model, tokens,
-                                                 position_ids,
-                                                 attention_mask,
-                                                 self.inference_params,
-                                                 micro_batch_size)
-        else:
-            return _no_pipelining_forward_step(self.model, tokens,
-                                               position_ids,
-                                               attention_mask,
-                                               self.inference_params)
+        """Invocation of the forward methods. Note that self.inference_params
+        is being modified by the forward step."""
+        # Pipelining case.
+        if self.pipeline_size_larger_than_one:
+            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
+            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
+                micro_batch_size = \
+                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
+                return _with_pipelining_forward_step(self.model,
+                                                     tokens,
+                                                     position_ids,
+                                                     attention_mask,
+                                                     self.inference_params,
+                                                     micro_batch_size)
+
+        return _no_pipelining_forward_step(self.model,
+                                           tokens,
+                                           position_ids,
+                                           attention_mask,
+                                           self.inference_params)


 def _get_recv_buffer_dtype(args):
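A hypothetical driver for ForwardStep, shown as an aside between hunks. It assumes megatron is initialized, the import path is as guessed, and `model`, `tokens`, `position_ids`, `attention_mask`, and `tokens_to_generate` already exist on this rank:

import torch
from megatron.text_generation.forward_step import ForwardStep  # assumed path

forward_step = ForwardStep(model,
                           max_batch_size=tokens.size(0),
                           max_sequence_len=tokens.size(1) + tokens_to_generate)
with torch.no_grad():
    # ForwardStep decides internally whether to micro-batch and pipeline;
    # on the last pipeline stage the return value is the logits tensor.
    logits = forward_step(tokens, position_ids, attention_mask)
    # After this first call the key/value memory has been allocated, and
    # _forward_step_helper below switches off the allocation flag in
    # forward_step.inference_params.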
@@ -103,9 +121,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
         recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

     # Receive from previous stage.
-    if not mpu.is_pipeline_first_stage():
-        torch.distributed.recv(recv_buffer,
-                               src=mpu.get_pipeline_model_parallel_prev_rank())
+    recv_from_prev_pipeline_rank_(recv_buffer)

     # Forward pass through the model.
     model.set_input_tensor(recv_buffer)
@@ -113,9 +129,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
                                 inference_params=inference_params)

     # Send output to the next stage.
-    if not mpu.is_pipeline_last_stage():
-        torch.distributed.send(output_tensor,
-                               mpu.get_pipeline_model_parallel_next_rank())
+    send_to_next_pipeline_rank(output_tensor)

     # Make sure we do not allocate context memory anymore.
     if inference_params.allocate_key_value_memory:
@@ -128,7 +142,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,

 def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                 inference_params, recv_buffer=None):
+    """If recv_buffer is None, we will allocate one on the fly."""
     # Run a simple forward pass.
     output_tensor = _forward_step_helper(model, tokens, position_ids,
                                          attention_mask, inference_params,
@@ -143,9 +157,10 @@ def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
     return logits


 def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                   inference_params, micro_batch_size):
+    """No interleaving is supported."""
     sequence_length = tokens.size(1)
     batch_size = tokens.size(0)
...
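The elided body of _with_pipelining_forward_step iterates over micro-batches of the prompt batch. A minimal, self-contained sketch of just that batch-splitting step, with illustrative names rather than the file's actual code:

import torch

def micro_batch_ranges(batch_size, micro_batch_size):
    # Yield (start, end) index pairs covering the batch dimension in chunks
    # of at most micro_batch_size, the way the pipelined forward step walks
    # over the batch.
    for start in range(0, batch_size, micro_batch_size):
        yield start, min(start + micro_batch_size, batch_size)

# Example: a batch of 10 prompts with micro_batch_size 4 is processed as
# slices (0, 4), (4, 8), (8, 10).
tokens = torch.zeros(10, 16, dtype=torch.long)
print(list(micro_batch_ranges(tokens.size(0), 4)))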
@@ -32,10 +32,12 @@ def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
         return_all_log_probs=False,
-        temperature=1.0):
+        greedy=False, top_k=0, top_p=0.0,
+        temperature=1.0,
+        use_eod_token_for_early_termination=True):
     """Main token generation function.
     Arguments:
-        model: XXX
+        model: no interleaving is supported.
         tokens: prompt tokens extended to be of size [b, max-sequence-length]
         lengths: original prompt length, size: [b]
         return_output_log_probs: flag to calculate the log probability of
@@ -44,7 +46,14 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_all_log_probs: flag to calculate the log probability across
            all the tokens (vocab size). Note that the log probability is the
            one after logits are modified for sampling.
+        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
+           Note that these three parameters are exclusive, meaning that:
+              if greedy = true then we should have top-k = top-p = 0.
+              if top-k > 0 then we expect greedy = false and top-p = 0.
+              if top-p > 0 then we check for greedy = false and top-k = 0.
         temperature: sampling temperature.
+        use_eod_token_for_early_termination: if True, do early termination if
+           all the sequences have reached this token.
     Note: Outside of model, other parameters only need to be available on
        rank 0.
     Outputs: Note that the size is adjusted to a lower value than
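To make the docstring's sampling parameters concrete, here is the common top-k / top-p (nucleus) filtering scheme they refer to. This is an illustrative sketch, not the module's sample() implementation:

import torch

def filter_logits(logits, top_k=0, top_p=0.0):
    # logits: [batch, vocab]. Greedy decoding corresponds to keeping only
    # the single most likely token (top_k=1 behaves the same way here).
    if top_k > 0:
        # Remove everything below the k-th largest logit.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Mask tokens once cumulative probability exceeds top_p, always
        # keeping the highest-probability token.
        mask = cumulative > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        remove = mask.scatter(-1, sorted_idx, mask)
        logits = logits.masked_fill(remove, float('-inf'))
    return logits

# Example: keep at most 3 tokens, and within those the smallest set whose
# probabilities sum to at least 0.9.
print(filter_logits(torch.randn(2, 8), top_k=3, top_p=0.9))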
@@ -108,10 +117,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     # =============
     # Run inference
     # =============

-    with torch.no_grad():
-        attention_mask, position_ids = _build_attention_mask_and_position_ids(
-            tokens)
+    attention_mask, position_ids = _build_attention_mask_and_position_ids(
+        tokens)
+
+    with torch.no_grad():
         prev_context_length = 0
         for context_length in range(min_prompt_length, max_sequence_length):
@@ -132,9 +140,9 @@ def generate_tokens_probs_and_return_on_first_stage(
             last_token_logits = logits[:, -1, :]
             new_sample, updated_last_token_logits = sample(
                 last_token_logits,
-                greedy=args.greedy,
-                top_k=args.top_k,
-                top_p=args.top_p,
+                greedy=greedy,
+                top_k=top_k,
+                top_p=top_p,
                 temperature=temperature,
                 vocab_size=tokenizer.vocab_size)

             # Now that we have the sample and updated logits,
@@ -189,8 +197,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             done = torch.all(is_generation_done)
             done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                       tensor=done)
-            #if done:
-            #    break
+            if use_eod_token_for_early_termination and done:
+                break

     # ===================================================
     # Update the lengths based on max generated length.
...
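A toy, self-contained illustration of the EOD-based early-termination bookkeeping re-enabled above; the token values and tensor names are made up for the example:

import torch

eod = 0                      # assumed end-of-document token id
batch_size = 4
is_generation_done = torch.zeros(batch_size, dtype=torch.bool)

# Two fake generation steps; each entry is the token sampled per sequence.
for step, new_tokens in enumerate([torch.tensor([5, 0, 7, 2]),
                                   torch.tensor([0, 3, 0, 0])]):
    is_generation_done |= (new_tokens == eod)
    if torch.all(is_generation_done):
        # With use_eod_token_for_early_termination=True, generation stops
        # as soon as every sequence in the batch has produced an EOD token.
        print('all sequences done at step', step)
        break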