Commit 6c40f892 authored by mshoeybi

working

parent 25f9c3f0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference API."""
import torch
from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import tokenize_prompts
def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
return_all_log_probs=False,
temperature=1.0):
"""TO DO ..."""
# Make sure input params are available to all ranks.
values = [tokens_to_generate, return_output_log_probs,
return_all_log_probs, temperature]
values_float_tensor = broadcast_float_list(4, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
return_all_log_probs = bool(values_float_tensor[2].item())
temperature = values_float_tensor[3].item()
# Tokenize prompts and get the batch.
# Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
return_all_log_probs=return_all_log_probs,
temperature=temperature)
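For context, a minimal usage sketch of this API (the model and prompt strings below are placeholders, not part of this commit; only rank 0 needs the actual prompts):
# Hypothetical driver, shown only to illustrate the call signature.
if torch.distributed.get_rank() == 0:
    prompts = ["Deep learning is", "Megatron-LM is"]
else:
    prompts = None  # non-zero ranks receive the broadcasted values inside generate()
tokens, lengths, output_log_probs, all_log_probs = generate(
    model,                           # a Megatron GPT model, already built and loaded
    prompts=prompts,
    tokens_to_generate=32,
    return_output_log_probs=True,
    temperature=0.9)
# The outputs are populated on the first pipeline stage; on other stages some
# of these values may come back as None from the underlying generation routine.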
@@ -40,6 +40,33 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
# Only first and last stage pipeline stages need to be involved.
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
if is_last_stage or is_first_stage:
if is_last_stage:
assert tensor is not None
assert tensor.is_cuda
assert tensor.is_contiguous()
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
@@ -48,20 +75,24 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
if is_last_stage or is_first_stage:
assert tensor is not None
assert tensor.is_cuda
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_last_stage:
assert tensor is not None
assert tensor.is_cuda
tensor_ = tensor.contiguous()
if is_contiguous:
tensor_ = tensor
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage:
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
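A rough sketch of how these two helpers are meant to be used across pipeline stages (the shape and the compute_result() call are placeholders for illustration only):
# Hypothetical illustration, not part of this commit.
size = (4, 128)                      # placeholder shape known to all ranks
if mpu.is_pipeline_last_stage():
    result = compute_result()        # placeholder: a contiguous CUDA tensor of `size`
else:
    result = None
# Every rank calls the helper; only the first and last stages take part in the
# broadcast over the embedding group, the remaining stages simply get None back.
result = broadcast_from_last_to_first_pipeline_stage(size, torch.float32, result)
# Alternatively, if both the last stage (source values) and the first stage
# (destination buffer) already hold a tensor of that size,
# copy_from_last_to_first_pipeline_stage(size, torch.float32, tensor) updates
# the first stage's tensor in place instead of allocating a new one.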
@@ -19,19 +19,44 @@
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer
from megatron import mpu
from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage)
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step
from .sampling import sample
def generate_tokens(model, tokens, lengths, return_all_probs=False,
temperature=1.0):
"""Main token generation function."""
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
return_all_log_probs=False,
temperature=1.0):
"""Main token generation function.
Arguments:
model: XXX
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
after logits are modifed for sampling.
return_all_log_probs: flag to calculate the log probability of across
all the tokens (vocab size). Note that the log probability is the
one after logits are modifed for sampling.
temperature: sampling temperature.
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs: Note that is size is adjusted to a lower value than
max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
all_log_probs: log probability of all the tokens.
size: [b, s, vocab-size]
"""
args = get_args()
tokenizer = get_tokenizer()
@@ -52,18 +77,35 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens)
output_log_probs = torch.empty(batch_size, max_sequence_length - 1,
dtype=torch.float32,
device=torch.cuda.current_device())
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
# Log probability of all tokens for the sequence.
all_log_probs = None
all_log_probs_size = (batch_size, max_sequence_length - 1,
args.padded_vocab_size)
# Lengths of generated sequences including prompts.
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
if return_all_log_probs:
all_log_probs = torch.empty(all_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
# Run inference
# =============
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
@@ -114,15 +156,25 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens that we need to get the log probabilities for.
# Note that next input token is the token which we selected in
# the current logits, so shift by 1.
indices = torch.unsqueeze(
tokens[:, (prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:, prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
if return_output_log_probs or return_all_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_all_log_probs:
all_log_probs[:,
prev_context_length:context_length,
:] = log_probs
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
@@ -147,17 +199,36 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
if done:
break
if mpu.is_pipeline_last_stage():
if return_all_probs:
full_logits = None
return tokens, generated_sequence_lengths, output_log_probs, \
full_logits, context_length + 1
return tokens, generated_sequence_lengths, output_log_probs, \
None, context_length + 1
if mpu.is_pipeline_first_stage():
return tokens, None, None, None, context_length + 1
return None, None, None, None, context_length + 1
# ===================================================
# Update the lengths based on the max generated length.
# ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
if return_all_log_probs:
all_log_probs = all_log_probs[:, :context_length, :]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
if return_all_log_probs:
all_log_probs_size = (batch_size, context_length,
args.padded_vocab_size)
all_log_probs = broadcast_from_last_to_first_pipeline_stage(
all_log_probs_size, torch.float32, all_log_probs)
return tokens, generated_sequence_lengths, output_log_probs, \
all_log_probs
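To make the shift-by-one gather above concrete, here is a tiny self-contained PyTorch sketch (not code from this commit) of how each selected token is scored by the logits at the previous position:
import torch
import torch.nn.functional as F
# Toy shapes: batch 1, 4 logit positions, vocab of 5.
logits = torch.randn(1, 4, 5)           # logits[:, t] is the distribution for token t + 1
tokens = torch.randint(0, 5, (1, 5))    # tokens[:, t + 1] was sampled from logits[:, t]
log_probs = F.log_softmax(logits, dim=2)
# Shift by one: position t + 1 of `tokens` is scored by position t of `log_probs`.
indices = tokens[:, 1:].unsqueeze(2)                         # [1, 4, 1]
selected = torch.gather(log_probs, 2, indices).squeeze(2)    # [1, 4]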
def _build_attention_mask_and_position_ids(tokens):
@@ -23,6 +23,39 @@ from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
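A small usage sketch (the token ids are placeholder values, and the GPT-2 style tokenizer.tokenizer.decoder / byte_decoder attributes used above are assumed to exist on the configured Megatron tokenizer):
# Hypothetical usage, not part of this commit.
tokens_gpu = torch.tensor([[15496, 995, 0, 0]], device='cuda')   # padded token ids (placeholders)
lengths_gpu = torch.tensor([2], device='cuda')                   # only the first 2 ids are real
tokens, texts, segments = detokenize_generations(
    tokens_gpu, lengths_gpu, return_segments=True)
# texts[0] is the detokenized string of the first sequence and segments[0] is
# its per-token list of decoded words.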
def tokenize_prompts(prompts=None, tokens_to_generate=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
@@ -153,8 +153,12 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
if mpu.is_pipeline_last_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
print('last rank output size {} {} | \n'.format(output_logits.size(0), output_logits.size(1)))
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
print('last rank full size {} {} {} | \n'.format(full_logits.size(0),
full_logits.size(1),
full_logits.size(2)))
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(full_logits, src, group)
@@ -164,13 +168,18 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
print('first rank output size {} {} | \n'.format(output_logits.size(0), output_logits.size(1)))
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
args = get_args()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
full_logits = torch.empty(tokens.size(0), context_length-1, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
print('first rank full size {} {} {} | \n'.format(full_logits.size(0),
full_logits.size(1),
full_logits.size(2)))
torch.distributed.broadcast(full_logits, src, group)
if tokens is not None:
return tokens[:, :context_length], output_logits, full_logits
@@ -204,7 +213,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
output_logits = output_logits.cpu().numpy().tolist()
if all_probs:
full_logits = full_logits.cpu().numpy().tolist()
full_logits = full_logits.cpu().numpy() #.tolist()
return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens