Commit 0b0e37f0 authored by Jared Casper

Merge branch 'inference' into 'main'

Inference refactoring

See merge request ADLR/megatron-lm!339
parents 5478d67e 71359e1f
@@ -427,7 +427,7 @@ Several downstream tasks are described for both GPT and BERT models below. They

 ## GPT Text Generation

-We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
+We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.

 Once the server is running you can use `tools/text_generation_cli.py` to query it; it takes one argument, which is the host the server is running on.
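For example, a query can be issued with `python tools/text_generation_cli.py <host>`, replacing `<host>` with whatever address the server prints at startup (the exact address and port depend on how the server was launched and are only placeholders here).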
......
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_inference_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -279,6 +280,18 @@ def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


+def _add_inference_args(parser):
+    group = parser.add_argument_group(title='inference')
+
+    group.add_argument('--inference-batch-times-seqlen-threshold',
+                       type=int, default=512,
+                       help='During inference, if batch-size times '
+                       'sequence-length is smaller than this threshold '
+                       'then we will not use pipelining, otherwise we will.')
+
+    return parser
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
......
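To make the new `--inference-batch-times-seqlen-threshold` flag concrete, the following minimal sketch shows the decision rule it is assumed to control, mirroring the `ForwardStep` logic introduced later in this merge request; it is illustrative only, not library code.

def use_pipelining(batch_size, seq_len, threshold=512):
    # Pipelining over micro-batches only pays off once the workload is large
    # enough, i.e. batch size times sequence length crosses the threshold.
    return batch_size * seq_len >= threshold

# Example: with the default threshold of 512, a batch of 4 prompts of length
# 200 gives 4 * 200 = 800 >= 512, so the batch is pipelined and split into
# micro-batches of size max(1, 512 // 200) = 2.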
@@ -82,16 +82,13 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
......
@@ -386,8 +386,7 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None,
+                inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
@@ -404,8 +403,7 @@ class TransformerLanguageModel(MegatronModule):
                 encoder_output = self.encoder(
                     encoder_input,
                     enc_attn_mask,
-                    set_inference_key_value_memory=set_inference_key_value_memory,
-                    inference_max_sequence_len=inference_max_sequence_len)
+                    inference_params=inference_params)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
@@ -438,8 +436,7 @@ class TransformerLanguageModel(MegatronModule):
             dec_attn_mask,
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
......
@@ -187,6 +187,10 @@ class Float16Module(MegatronModule):
         self.float16_convertor = float16_convertor

+    def set_input_tensor(self, input_tensor):
+        return self.module.set_input_tensor(input_tensor)
+
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
             inputs = fp32_to_float16(inputs, self.float16_convertor)
......
@@ -182,7 +182,6 @@ class ParallelAttention(MegatronModule):
         # Inference key-value memory
         self.inference_key_memory = None
         self.inference_value_memory = None
-        self.inference_current_sequence_len = 0

     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -196,35 +195,28 @@ class ParallelAttention(MegatronModule):
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
-        if set_inference_key_value_memory:
-            assert inference_max_sequence_len and inference_max_sequence_len > 0
-            self.inference_key_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_value_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_current_sequence_len = 0
-        # Some consistency check.
-        if inference_max_sequence_len:
-            assert self.inference_current_sequence_len < \
-                self.inference_key_memory.size(0)
-            assert inference_max_sequence_len == \
-                self.inference_key_memory.size(0)
-        # This is added for safety. In case inference_max_sequence_len
+        if inference_params:
+            if inference_params.allocate_key_value_memory:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_batch_size = inference_params.max_batch_size
+                self.inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                self.inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+        # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
-        if not inference_max_sequence_len:
+        else:
             self.inference_key_memory = None
             self.inference_value_memory = None

         # =====================
         # Query, Key, and Value
@@ -267,22 +259,28 @@ class ParallelAttention(MegatronModule):
             query_layer = query_layer.view(*new_tensor_shape)

-        # ===================================================
-        # Adjust key, value, and attention mask for inference
-        # ===================================================
-        if inference_max_sequence_len:
-            # Adjust the range variables.
-            start = self.inference_current_sequence_len
-            self.inference_current_sequence_len += key_layer.size(0)
-            end = self.inference_current_sequence_len
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================
+        if inference_params:
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= self.inference_key_memory.size(1)
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= self.inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory[start:end, ...] = key_layer
-            self.inference_value_memory[start:end, ...] = value_layer
-            key_layer = self.inference_key_memory[:end, ...]
-            value_layer = self.inference_value_memory[:end, ...]
-            # Adjust attention mask
-            attention_mask = attention_mask[..., start:end, :end]
+            self.inference_key_memory[sequence_start:sequence_end,
+                                      batch_start:batch_end,
+                                      ...] = key_layer
+            self.inference_value_memory[sequence_start:sequence_end,
+                                        batch_start:batch_end,
+                                        ...] = value_layer
+            key_layer = self.inference_key_memory[
+                :sequence_end, batch_start:batch_end, ...]
+            value_layer = self.inference_value_memory[
+                :sequence_end, batch_start:batch_end, ...]

         # ===================================
@@ -465,10 +463,8 @@ class ParallelTransformerLayer(MegatronModule):
             output_layer_init_method)

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
         # hidden_states: [b, s, h]

         # Layer norm at the beginning of the transformer layer.
@@ -478,8 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
             self.self_attention(
                 layernorm_output,
                 attention_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -691,13 +686,11 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = input_tensor

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):

         # Checks.
-        if inference_max_sequence_len:
+        if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'
@@ -729,8 +722,8 @@ class ParallelTransformer(MegatronModule):
                 attention_mask,
                 encoder_output=encoder_output,
                 enc_dec_attn_mask=enc_dec_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
......
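Conceptually, the pre-allocated key/value memory above is indexed as [sequence, batch, ...], and the offsets carried by `inference_params` decide where a newly computed chunk of keys and values is written and how much of the cache is read back. A minimal, self-contained sketch of that bookkeeping with simplified shapes (the real tensors also carry per-partition head dimensions, which are omitted here):

import torch

max_seq, max_batch, head_dim = 16, 4, 8
key_memory = torch.zeros(max_seq, max_batch, head_dim)

# First call: cache the whole 5-token prompt for micro-batch rows 0..1.
sequence_start, batch_start = 0, 0
key_memory[sequence_start:sequence_start + 5,
           batch_start:batch_start + 2] = torch.randn(5, 2, head_dim)

# Later decode step: one new token per sequence; everything cached so far
# is then read back for attention.
sequence_start = 5
key_memory[sequence_start:sequence_start + 1,
           batch_start:batch_start + 2] = torch.randn(1, 2, head_dim)
keys_for_attention = key_memory[:sequence_start + 1, batch_start:batch_start + 2]
assert keys_for_attention.shape == (6, 2, head_dim)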
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .api import (
generate,
generate_and_post_process)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference API."""
import torch
from megatron import mpu
from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import (
tokenize_prompts,
detokenize_generations)
def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs = generate(
model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
tokens, prompts_plus_generations, prompts_plus_generations_segments = \
detokenize_generations(tokens, lengths, True)
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
return None
def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
"""
# Make sure input params are available to all ranks.
values = [tokens_to_generate, return_output_log_probs,
top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(7, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
# Tokenize prompts and get the batch.
# Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
assert tokens_to_generate > 0
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
top_k=top_k_sampling,
top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
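For orientation, a driver-side call into this API could look like the sketch below; the `model` argument is assumed to be an already-initialized Megatron GPT model running inside an initialized distributed job, which is outside the scope of this file.

def run_generation(model):
    # Hypothetical helper showing the calling convention of the API above;
    # unspecified sampling arguments fall back to the defaults defined here.
    result = generate_and_post_process(
        model,
        prompts=["Megatron-LM is", "Pipeline parallelism"],
        tokens_to_generate=32,
        return_output_log_probs=True,
        top_p_sampling=0.9)
    # Only the first pipeline stage receives non-None post-processed output.
    if result is not None:
        texts, segments, log_probs, tokens = result
        return texts
    return None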
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Communications utilities."""
import torch
from megatron import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer inplace."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
# If the first stage and the last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If the first stage and the last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If the first stage and the last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
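The intended usage pattern for these helpers, as seen in `api.py` above, is that a single rank owns the real values while every rank calls the helper with the same size and dtype; a minimal sketch, assuming `torch.distributed` has already been initialized:

# Rank 0 supplies the values; all ranks receive the same broadcast tensor.
values = None
if torch.distributed.get_rank() == 0:
    values = [64, 1, 0, 0.9, 1.0]   # e.g. tokens_to_generate, flags, top-p, ...
values_tensor = broadcast_float_list(5, float_list=values)
tokens_to_generate = int(values_tensor[0].item())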
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron import (
get_args,
mpu)
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_len):
"""Note that offsets are set to zero and we always set the
flag to allocate memory. After the first call, make sure to
set this flag to False."""
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.allocate_key_value_memory = True
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_len):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params)
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
# Make sure we do not allocate context memory anymore.
if inference_params.allocate_key_value_memory:
inference_params.allocate_key_value_memory = False
return output_tensor
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Update the sequence length offset.
inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
# Slice along the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = _forward_step_helper(model, tokens2use, position_ids2use,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
inference_params.batch_size_offset = 0
return logits
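Taken together, the pieces above expect the generation loop (see `generation.py`) to drive them roughly as follows; this is a minimal illustration of the `InferenceParams` bookkeeping only, using the class defined earlier in this file.

params = InferenceParams(max_batch_size=4, max_sequence_len=256)
assert params.allocate_key_value_memory    # the first forward pass allocates the KV cache
params.allocate_key_value_memory = False   # cleared after the first step (see _forward_step_helper)
params.sequence_len_offset += 17           # e.g. a 17-token prompt has been consumed
params.sequence_len_offset += 1            # then one token per subsequent decode step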
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
"""Main token generation function.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
Note that top-k = 1 is greedy. Also, these parameters are
exclusive meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs: Note that the size is adjusted to a lower value than
max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
# Lengths of the generated sequences, including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
# Run inference
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# If a prompt length is smaller than or equal to the current context
# length, it means we have started generating tokens.
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the length based on the max generated length.
# ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
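To make the slicing in the generation loop above concrete, the following worked trace uses illustrative numbers only:

# min_prompt_length = 5, max_sequence_length = 8
# step 1: context_length = 5, tokens2use = tokens[:, 0:5]  (whole prompt, fills the KV cache)
# step 2: context_length = 6, tokens2use = tokens[:, 5:6]  (single newest token)
# step 3: context_length = 7, tokens2use = tokens[:, 6:7]
# The attention mask is sliced to [..., prev:cur, :cur] so each new query row
# can attend to every cached key up to the current context length.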
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filtering based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab-size). This will avoid out of vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# Greedy is just simple argmax.
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)
# Top-k or top-p sampling.
else:
# Clone so we do not modify the inputs,
logits = logits.clone()
# Apply temperature in place.
if temperature != 1.0:
logits.div_(temperature)
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(logits, top_k)
elif top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
modify_logits_for_top_p_filtering(logits, top_p)
# After filtering, we need to recalculate the distribution.
probs = logits.softmax(dim=-1)
samples = torch.multinomial(probs, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in
# in the range [0, vocab-size).
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
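A small usage sketch for `sample` (it requires a CUDA float tensor, matching the assertions above; the vocabulary size is just an example value):

logits = torch.randn(2, 50257, device=torch.cuda.current_device())
next_tokens = sample(logits, top_k=0, top_p=0.9, temperature=0.8,
                     vocab_size=50257)
# next_tokens has shape [2]: one sampled token id per row of logits.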
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization utilities."""
import torch
from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
def tokenize_prompts(prompts=None, tokens_to_generate=None,
add_BOS=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
# On all ranks set to None so we can pass them to functions
sizes_list = None
prompts_tokens_cuda_long_tensor = None
prompts_length_cuda_long_tensor = None
# On the specified rank, build the above.
if torch.distributed.get_rank() == rank:
assert prompts is not None
assert tokens_to_generate is not None
# Tensor of tokens padded and their unpadded length.
prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
_tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
# We need the sizes of these tensors for the broadcast.
sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
prompts_tokens_cuda_long_tensor.size(1)] # Sequence length
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
# Now that we have the sizes, we can broadcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
prompts_tokens_cuda_long_tensor = broadcast_tensor(
sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
prompts_length_cuda_long_tensor = broadcast_tensor(
sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
rank=rank)
return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
"""Given a set of prompts and number of tokens to generate:
- tokenize prompts
- set the sequence length to be the max of length of prompts
plus the number of tokens we would like to generate
- pad all the sequences to this length so we can convert them
into a 2D tensor.
"""
# Tokenize all the prompts.
tokenizer = get_tokenizer()
if add_BOS:
prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
for prompt in prompts]
else:
prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
# Now we have a list of lists of tokens, where each list has a different
# size. We want to extend this list to:
# - incorporate the tokens that need to be generated
# - make all the sequences equal length.
# Get the prompts length.
prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
# Get the max prompts length.
max_prompt_len = max(prompts_length)
# Number of tokens in each sample of the batch.
samples_length = max_prompt_len + tokens_to_generate
# Now update the list of list to be of the same size: samples_length.
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
padding_size = samples_length - prompt_length
prompt_tokens.extend([tokenizer.eod] * padding_size)
# Now we are in a structured format, we can convert to tensors.
prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
return prompts_tokens_tensor, prompts_length_tensor
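A worked example of the padding performed by `_tokenize_prompts_and_batch`, with made-up numbers:

# prompts_length = [3, 5] and tokens_to_generate = 4
# => samples_length = max(3, 5) + 4 = 9
# every row of prompts_tokens is extended with the eod token up to length 9,
# so the returned token tensor has shape [2, 9] and the length tensor is [3, 5].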
@@ -19,8 +19,8 @@ import threading
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
-from megatron import mpu
-from megatron.text_generation_utils import generate
+from megatron.text_generation import generate_and_post_process

 GENERATE_NUM = 0
 lock = threading.Lock()
@@ -67,7 +67,7 @@ class MegatronGenerate(Resource):
         if not isinstance(logprobs, bool):
             return "logprobs must be a boolean value"

-        temperature = args.temperature
+        temperature = 1.0
         if "temperature" in request.get_json():
             temperature = request.get_json()["temperature"]
             if not (type(temperature) == int or type(temperature) == float):
@@ -75,7 +75,7 @@ class MegatronGenerate(Resource):
             if not (0.0 < temperature <= 100.0):
                 return "temperature must be a positive number less than or equal to 100.0"

-        top_k = args.top_k
+        top_k = 0.0
         if "top_k" in request.get_json():
             top_k = request.get_json()["top_k"]
             if not (type(top_k) == int):
@@ -83,11 +83,13 @@ class MegatronGenerate(Resource):
             if not (0 < top_k <= 1000):
                 return "top_k must be equal to or greater than 0 and less than or equal to 1000"

-        top_p = args.top_p
+        top_p = 0.0
         if "top_p" in request.get_json():
             top_p = request.get_json()["top_p"]
             if not (type(top_p) == float):
                 return "top_p must be a positive float less than or equal to 1.0"
+            if top_p > 0.0 and top_k > 0.0:
+                return "cannot set both top-k and top-p samplings."
             if not (0 < top_p <= 1.0):
                 return "top_p must be less than or equal to 1.0"
@@ -99,14 +101,17 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs = generate(self.model,
-                                                                 prompts,
-                                                                 tokens_to_generate,
-                                                                 logprobs,
-                                                                 temperature,
-                                                                 top_k,
-                                                                 top_p,
-                                                                 add_BOS)
+            response, response_seg, response_logprobs, _ = \
+                generate_and_post_process(
+                    self.model,
+                    prompts=prompts,
+                    tokens_to_generate=tokens_to_generate,
+                    return_output_log_probs=logprobs,
+                    top_k_sampling=top_k,
+                    top_p_sampling=top_p,
+                    temperature=temperature,
+                    add_BOS=add_BOS,
+                    use_eod_token_for_early_termination=True)

         return jsonify({"text": response,
                         "segments": response_seg,
......
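For reference, a client request satisfying the checks in this handler might look like the sketch below; the HTTP verb, endpoint path, port, and the exact JSON key names for the prompt fields are assumptions based on the surrounding code rather than something shown in this hunk.

import requests

payload = {
    "prompts": ["Deep learning is"],   # assumed key name
    "tokens_to_generate": 32,          # assumed key name
    "logprobs": False,
    "temperature": 0.9,
    "top_p": 0.9,
    "top_k": 0,   # top_k and top_p are mutually exclusive per the new check above
}
response = requests.put("http://localhost:5000/api", json=payload)  # verb/path/port assumed
print(response.json()["text"])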
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
import copy
import json
import os
import time
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313 """
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] \
= sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
def pad_batch(batch, pad_id, max_len):
context_lengths = []
max_context_length = max([len(tokens) for tokens in batch])
for tokens in batch:
context_length = len(tokens)
if context_length < max_context_length + max_len:
tokens.extend([pad_id] * (max_context_length + max_len - context_length))
context_lengths.append(context_length)
return batch, context_lengths
def tokenize_batch(sentences, max_len, add_BOS):
args = get_args()
tokenizer = get_tokenizer()
if add_BOS:
context_tokens = [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
else:
context_tokens = [tokenizer.tokenize(s) for s in sentences]
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, max_len)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
return context_tokens_tensor, context_length_tensor
def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
"""
Needs to be synced up with receive_generate_info
"""
# Send the sizes of the tensors
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
input_info_tensor = torch.cuda.FloatTensor(input_info)
torch.distributed.broadcast(input_info_tensor, 0)
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = int(input_info_tensor[0].item())
seq_len = int(input_info_tensor[1].item())
tokens_to_generate = int(input_info_tensor[2].item())
logprobs = bool(input_info_tensor[3].item())
temperature = float(input_info_tensor[4].item())
top_k = int(input_info_tensor[5].item())
top_p = float(input_info_tensor[6].item())
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p
def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
tokens_to_generate,
logprobs,
temperature,
top_k,
top_p)
for tokens, lengths, output_logits in batch_token_iterator:
context_length += 1
if logprobs:
if mpu.is_pipeline_last_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(output_logits, src, group)
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
torch.distributed.broadcast(output_logits, src, group)
if tokens is not None:
return tokens[:, :context_length], output_logits
def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
model.eval()
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
else:
context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()
output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
if output is not None:
decode_tokens, output_logits = output
args = get_args()
tokenizer = get_tokenizer()
resp_sentences = []
resp_sentences_seg = []
decode_tokens = decode_tokens.cpu().numpy().tolist()
for i, decode_token in enumerate(decode_tokens):
resp_sentences.append(tokenizer.detokenize(decode_token))
words = []
for token in decode_token:
word = tokenizer.tokenizer.decoder[token]
word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
words.append(word)
resp_sentences_seg.append(words)
if logprobs:
output_logits = output_logits.cpu().numpy().tolist()
return resp_sentences, resp_sentences_seg, output_logits
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
"""
This function is here to provide a matching API for a legacy task.
This implementation hasn't been tested yet to make sure it matches
"""
#assert False, "Implementation untested"
args = get_args()
args.eos_id = eos_token_id
raw_text_len = len(context)
resp_sentences = generate(model, [context], max_gen_length)
if resp_sentences:
return resp_sentences[0][raw_text_len:]
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
# Hidden size changes when not using recompute, need to tell p2p_communicate
# functions the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
args.micro_batch_size = tokens.shape[0]
input_tensor = recv_forward()
# Forward pass through the model.
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor = model(
tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
send_forward(output_tensor)
args.seq_length = orig_seq_length
return output_tensor
def sample_sequence_batch(model,
context_tokens,
context_lengths,
attention_mask,
position_ids,
tokens_to_generate,
logprobs,
temperature,
top_k,
top_p,
type_ids=None):
args = get_args()
tokenizer = get_tokenizer()
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
# added eos_id to support the function generate_samples_eval that passes
# eos_id as an argument and needs termination when that id is found.
if hasattr(args, 'eos_id'):
eos_id = args.eos_id
else:
eos_id = tokenizer.eod
counter = 0
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
output_logits = None
# Generate enough tokens for the longest sequence
maxlen = tokens_to_generate + context_lengths.max().item()
if maxlen > args.seq_length:
maxlen = args.seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length < maxlen:
types2use = None
if counter == 0:
# Allocate memory for the entire context.
set_inference_key_value_memory = True
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
# Set this to false so the memory is not reallocated.
set_inference_key_value_memory = False
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
output = forward_step(
model, tokens2use,
positions2use,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=maxlen,
tokentype_ids=types2use)
if mpu.is_pipeline_last_stage():
assert output is not None
output = output.float()
logits = output[:, -1].view(batch_size, -1).contiguous()
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
logits /= temperature
logits = top_k_logits(logits, top_k=top_k,
top_p=top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
# Clamp the out of vocabulary tokens.
tokenizer = get_tokenizer()
prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
if logprobs:
if output_logits is None:
output_context = F.log_softmax(output[:, :context_length, :], 2)
indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
output_logits = torch.gather(output_context, 2, indices).squeeze(2)
else:
output_context = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits],1)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths, output_logits
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None, None
else:
yield None, None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
context_length += 1
counter += 1
if done:
break
@@ -27,7 +27,7 @@ from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_server import MegatronServer
-from megatron.text_generation_utils import generate
+from megatron.text_generation import generate_and_post_process
 import torch

 def model_provider(pre_process=True, post_process=True):
@@ -43,8 +43,6 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0,
                        help='Sampling temperature.')
-    group.add_argument("--greedy", action='store_true', default=False,
-                       help='Use greedy sampling.')
     group.add_argument("--top_p", type=float, default=0.0,
                        help='Top p sampling.')
     group.add_argument("--top_k", type=int, default=0,
@@ -80,4 +78,4 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate(model)
+            generate_and_post_process(model)