Commit 25f9c3f0 authored by mshoeybi

refactoring, tested and working

parent e722c4a9
@@ -18,6 +18,51 @@
import torch

from megatron import mpu


def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    if mpu.is_pipeline_last_stage():
        assert tensor is not None
        assert tensor.is_cuda
        assert tensor.is_contiguous()
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last stage into the first stage.
    Note that the input tensor is updated in place."""

    # Only the first and last pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    if is_last_stage or is_first_stage:
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_last_stage:
            assert tensor is not None
            assert tensor.is_cuda
            tensor_ = tensor.contiguous()
        else:
            tensor_ = torch.empty(size,
                                  dtype=dtype,
                                  device=torch.cuda.current_device())
        # Broadcast from the last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor in place.
        if is_first_stage:
            tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
......
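A minimal usage sketch of the helpers above (hypothetical driver code, not part of this commit; the names below are placeholders): both helpers wrap a collective broadcast, so every pipeline rank must call them with the same size and dtype, and only the last stage supplies a real source tensor.

def share_new_tokens(sampled_tokens, batch_size):
    # `sampled_tokens` is a CUDA int64 tensor of shape [batch_size] on the
    # last pipeline stage and None on every other rank (placeholder names).
    tensor = sampled_tokens if mpu.is_pipeline_last_stage() else None
    # Every rank enters the collective; non-last stages receive the result.
    return broadcast_from_last_pipeline_stage(batch_size, torch.int64,
                                              tensor=tensor)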
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forward step utilities."""
import torch
from megatron.p2p_communication import recv_forward, send_forward
from .sampling import sample
from megatron import mpu
import torch.nn.functional as F
from megatron import print_rank_0
from megatron import get_args, get_tokenizer
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from .communication import (
    broadcast_float_list,
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage)
from .tokenization import tokenize_prompts

# These are needed to unwrap the model; it would be nice to move them into
# megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def forward_step(model, tokens, position_ids, attention_mask,
                 set_inference_key_value_memory=False,
                 inference_max_sequence_len=None):

    # The shape of the communicated hidden states changes when not using
    # recompute, so we need to tell the p2p_communication functions the
    # correct size.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    args.micro_batch_size = tokens.shape[0]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(
        tokens, position_ids, attention_mask,
        set_inference_key_value_memory=set_inference_key_value_memory,
        inference_max_sequence_len=inference_max_sequence_len)

    send_forward(output_tensor)

    args.seq_length = orig_seq_length

    return output_tensor
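A small sketch of why args.seq_length and args.micro_batch_size are overridden above, under the assumption that recv_forward/send_forward size their pipeline-parallel buffers from the global args; the helper name below is hypothetical.

def _expected_p2p_buffer_shape(args, tokens):
    # Hidden states are exchanged as [sequence_length, micro_batch_size,
    # hidden_size]; during incremental decoding the sequence dimension is the
    # length of the token slice fed at this step, not the training sequence
    # length, hence the temporary override of the args.
    return (tokens.shape[1], tokens.shape[0], args.hidden_size)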
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage)
from .forward_step import forward_step
from .sampling import sample
def generate_tokens(model, tokens, lengths, return_all_probs=False,
                    temperature=1.0):
    """Main token generation function."""
    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length,
                              args.max_position_embeddings)

    # Added termination_id to support the case where we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = torch.empty(batch_size, max_sequence_length - 1,
                                   dtype=torch.float32,
                                   device=torch.cuda.current_device())
    # Lengths of the generated sequences, including the prompts.
    generated_sequence_lengths = torch.ones(
        batch_size, dtype=torch.int64,
        device=torch.cuda.current_device()) * max_sequence_length
    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    attention_mask, position_ids = _build_attention_mask_and_position_ids(
        tokens)
    model.eval()
    with torch.no_grad():
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # If we are starting from scratch, allocate memory for the entire
            # context; otherwise set this to False so the memory is not
            # reallocated.
            set_inference_key_value_memory = (prev_context_length == 0)

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(
                model, tokens2use, positions2use, attention_mask2use,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=max_sequence_length)

            if mpu.is_pipeline_last_stage():
                # The last stage should always have an output.
                assert logits is not None

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample, updated_last_token_logits = sample(
                    last_token_logits,
                    greedy=args.greedy,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    temperature=temperature,
                    vocab_size=tokenizer.vocab_size)

                # Now that we have the sample and the updated logits, update
                # the main logits and the input tokens.
                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens.
                started = lengths <= context_length
                # Update the logits ...
                last_token_logits.masked_scatter_(
                    started.unsqueeze(1), updated_last_token_logits[started])
                # ... and the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                log_probs = F.log_softmax(logits, dim=2)
                # Pick the tokens that we need the log probabilities for.
                # Note that the next input token is the token that we selected
                # from the current logits, so shift by 1.
                indices = torch.unsqueeze(
                    tokens[:, (prev_context_length + 1):(context_length + 1)],
                    2)
                output_log_probs[:, prev_context_length:context_length] = \
                    torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                done_token = (new_sample == termination_id).byte() & \
                    started.byte()
                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if done:
                break

    if mpu.is_pipeline_last_stage():
        if return_all_probs:
            full_logits = None
            return tokens, generated_sequence_lengths, output_log_probs, \
                full_logits, context_length + 1
        return tokens, generated_sequence_lengths, output_log_probs, \
            None, context_length + 1

    if mpu.is_pipeline_first_stage():
        return tokens, None, None, None, context_length + 1

    return None, None, None, None, context_length + 1
def _build_attention_mask_and_position_ids(tokens):
    """Build the attention mask and position ids for the input tokens."""

    # Since we are not interested in the loss mask and reset attention/position
    # is also False, eod_token is not used, so it is safe to set it to None.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False)

    return attention_mask, position_ids
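A rough end-to-end sketch of how generate_tokens might be driven (hypothetical driver, not part of this commit; prompt preparation and detokenization details are assumptions):

from megatron import get_tokenizer, mpu

def run_generation(model, tokens, lengths):
    """Hypothetical driver. `tokens` is a [batch_size, max_sequence_length]
    int64 CUDA tensor with each prompt followed by padding; `lengths` holds
    the prompt lengths."""
    tokenizer = get_tokenizer()
    tokens, seq_lengths, log_probs, _, _ = generate_tokens(
        model, tokens, lengths, temperature=1.0)
    # Only the last pipeline stage gets both the tokens and the lengths.
    if mpu.is_pipeline_last_stage():
        return [tokenizer.detokenize(row[:length].tolist())
                for row, length in zip(tokens, seq_lengths)]
    return None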
@@ -281,8 +281,6 @@ class ParallelAttention(MegatronModule):
            self.inference_value_memory[start:end, ...] = value_layer
            key_layer = self.inference_key_memory[:end, ...]
            value_layer = self.inference_value_memory[:end, ...]
            # Adjust the attention mask.
            attention_mask = attention_mask[..., start:end, :end]

        # ===================================
......
@@ -297,6 +297,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                positions2use = position_ids[:, :context_length]
                if type_ids is not None:
                    types2use = type_ids[:, :context_length]
                attention_mask2use = attention_mask[
                    ..., :context_length, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
@@ -307,11 +308,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)
                attention_mask2use = attention_mask[
                    ..., (context_length - 1):context_length, :context_length]

            output = forward_step(
                model, tokens2use,
                positions2use,
                attention_mask,
                attention_mask2use,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=maxlen,
                tokentype_ids=types2use)
......
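A standalone shape sketch (illustrative only, not part of this commit) of the mask slicing added above: once the key/value inference memory is populated, each step feeds only the newest token, so the mask needs a single query row but key columns for the whole context.

import torch

context_length = 5
# Causal mask for a maximum sequence length of 8; True means "masked out".
full_mask = torch.tril(torch.ones(1, 1, 8, 8)) < 0.5
prompt_mask = full_mask[..., :context_length, :context_length]
step_mask = full_mask[..., (context_length - 1):context_length, :context_length]
print(prompt_mask.shape)  # torch.Size([1, 1, 5, 5])
print(step_mask.shape)    # torch.Size([1, 1, 1, 5])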