Commit 6c40f892 authored by mshoeybi

working

parent 25f9c3f0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference API."""
import torch
from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import tokenize_prompts

def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             return_all_log_probs=False,
             temperature=1.0):
    """Given prompts, generate tokens and return the results on the
    first pipeline stage."""

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate, return_output_log_probs,
              return_all_log_probs, temperature]
    values_float_tensor = broadcast_float_list(4, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    temperature = values_float_tensor[3].item()

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcasted to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        temperature=temperature)
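
The new entry point above is thin: it broadcasts the sampling parameters, tokenizes the prompts, and delegates to the generation loop. A minimal usage sketch follows, assuming Megatron's distributed setup has already been initialized; the model variable, prompts, and import path are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; names and import path are assumptions.
from megatron.text_generation.api import generate

model = ...  # a pipeline-parallel GPT model built via Megatron's usual setup
prompts = ["Deep learning is", "The capital of France is"]

tokens, lengths, output_log_probs, all_log_probs = generate(
    model,
    prompts=prompts,
    tokens_to_generate=32,
    return_output_log_probs=True,
    temperature=0.9)

# Per the comments above, the outputs are only guaranteed to be meaningful
# on the first pipeline stage; other ranks simply participate in the
# broadcasts and the generation loop.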
@@ -40,6 +40,33 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    return tensor

def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    # Only the first and last pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    if is_last_stage or is_first_stage:
        if is_last_stage:
            assert tensor is not None
            assert tensor.is_cuda
            assert tensor.is_contiguous()
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor

def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""
@@ -48,11 +75,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    if is_last_stage or is_first_stage:
        assert tensor is not None
        assert tensor.is_cuda
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
@@ -61,7 +92,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
...
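
For context, here is a minimal sketch of how the new broadcast helper is meant to be driven; it mirrors its use in the generation loop further down, with an illustrative tensor name and batch size. Unlike copy_from_last_to_first_pipeline_stage, which updates a caller-provided tensor in place, the broadcast variant allocates the receive buffer on the first stage itself.

# Illustrative sketch: the last pipeline stage owns the data, the first
# stage passes tensor=None and gets a freshly allocated copy back.
batch_size = 4  # hypothetical
if mpu.is_pipeline_last_stage():
    lengths = torch.full((batch_size,), 128, dtype=torch.int64,
                         device=torch.cuda.current_device())
else:
    lengths = None
lengths = broadcast_from_last_to_first_pipeline_stage(
    batch_size, torch.int64, tensor=lengths)
# After the call, `lengths` is a CUDA tensor on the first and last pipeline
# stages, and None on all intermediate stages.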
@@ -19,19 +19,44 @@
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step
from .sampling import sample

def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        return_all_log_probs=False,
        temperature=1.0):
    """Main token generation function.
    Arguments:
        model: XXX
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            after logits are modified for sampling.
        return_all_log_probs: flag to calculate the log probability across
            all the tokens (vocab size). Note that the log probability is
            the one after logits are modified for sampling.
        temperature: sampling temperature.
        Note: Outside of model, other parameters only need to be available
            on rank 0.
    Outputs: Note that the sizes are adjusted to a lower value than
        max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
        all_log_probs: log probability of all the tokens.
            size: [b, s, vocab-size]
    """
    args = get_args()
    tokenizer = get_tokenizer()
@@ -52,11 +77,24 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Log probability of all tokens for the sequence.
    all_log_probs = None
    all_log_probs_size = (batch_size, max_sequence_length - 1,
                          args.padded_vocab_size)
    # Lengths of the generated sequence, including the prompt.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        if return_all_log_probs:
            all_log_probs = torch.empty(all_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length
@@ -64,6 +102,10 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    attention_mask, position_ids = _build_attention_mask_and_position_ids(
        tokens)
@@ -114,14 +156,24 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
            tokens[started, context_length] = new_sample[started]

            # Calculate the log probabilities.
            if return_output_log_probs or return_all_log_probs:
                log_probs = F.log_softmax(logits, dim=2)
                if return_all_log_probs:
                    all_log_probs[:,
                                  prev_context_length:context_length,
                                  :] = log_probs
                if return_output_log_probs:
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:,
                                     prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)
            # Update the tokens on the first stage so the next input to
@@ -147,17 +199,36 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
        if done:
            break

    # ===================================================
    # Update the lengths based on the max generated length.
    # ===================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]
        if return_all_log_probs:
            all_log_probs = all_log_probs[:, :context_length, :]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)
    if return_all_log_probs:
        all_log_probs_size = (batch_size, context_length,
                              args.padded_vocab_size)
        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
            all_log_probs_size, torch.float32, all_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs, \
        all_log_probs

def _build_attention_mask_and_position_ids(tokens):
...
@@ -23,6 +23,39 @@ from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor

def detokenize_generations(tokens_gpu_tensor,
                           lengths_gpu_tensor,
                           return_segments):
    """Detokenize the generated tokens."""

    tokenizer = get_tokenizer()

    prompts_plus_generations = []
    if return_segments:
        prompts_plus_generations_segments = []

    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
    for sequence_tokens, length in zip(tokens, lengths):
        sequence_tokens = sequence_tokens[:length]
        prompts_plus_generations.append(
            tokenizer.detokenize(sequence_tokens))
        if return_segments:
            words = []
            for token in sequence_tokens:
                word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
                        'utf-8', errors='replace')
                words.append(word)
            prompts_plus_generations_segments.append(words)

    if return_segments:
        return tokens, prompts_plus_generations, \
            prompts_plus_generations_segments

    return tokens, prompts_plus_generations

def tokenize_prompts(prompts=None, tokens_to_generate=None, rank=0):
    """Tokenize prompts and make them available on all ranks."""
...
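
A short sketch of how detokenize_generations might be called on the first pipeline stage once generation has finished; the tensor names are hypothetical, and the per-token segment path assumes a GPT-2 style BPE tokenizer, since it reads tokenizer.tokenizer.decoder and byte_decoder.

# Hypothetical usage sketch: `tokens` and `lengths` are the CUDA tensors
# returned by the generation call on the first pipeline stage.
token_ids, texts, segments = detokenize_generations(
    tokens, lengths, return_segments=True)
for text, words in zip(texts, segments):
    print(text)        # prompt plus generation as a single string
    print(words[:5])   # first few per-token word pieces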
@@ -153,8 +153,12 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
    if mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        print('last rank output size {} {} | \n'.format(output_logits.size(0),
                                                        output_logits.size(1)))
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            print('last rank full size {} {} {} | \n'.format(full_logits.size(0),
                                                             full_logits.size(1),
                                                             full_logits.size(2)))
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)
@@ -164,13 +168,18 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
            print('first rank output size {} {} | \n'.format(output_logits.size(0),
                                                             output_logits.size(1)))
            torch.distributed.broadcast(output_logits, src, group)
            if all_probs:
                args = get_args()
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                full_logits = torch.empty(tokens.size(0), context_length-1, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
                print('first rank full size {} {} {} | \n'.format(full_logits.size(0),
                                                                  full_logits.size(1),
                                                                  full_logits.size(2)))
                torch.distributed.broadcast(full_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits
@@ -204,7 +213,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
        output_logits = output_logits.cpu().numpy().tolist()
        if all_probs:
            full_logits = full_logits.cpu().numpy() #.tolist()

        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
...