Commit 0024a5c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'main' of https://github.com/NVIDIA/Megatron-LM

parents b004456b 3db2063b
Pipeline #229 failed with stages
in 0 seconds
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## from huggingface beam search
class BeamHypotheses(object):
def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs, length):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / length ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Communications utilities."""
import torch
from megatron.core import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer inplace."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron import get_args
from megatron.core import mpu
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_len):
"""Note that offsets are set to zero and we always set the
flag to allocate memory. After the first call, make sure to
set this flag to False."""
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert len(batch_idx) == inference_key_memory.shape[1] ## make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory, new_inference_value_memory)
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_len):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params)
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
return output_tensor
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Update the sequence length offset.
inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
# Slice among the batch dimenion.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = _forward_step_helper(model, tokens2use, position_ids2use,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
inference_params.batch_size_offset = 0
return logits
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer
from megatron.core import mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
from .beam_utils import BeamHypotheses
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max_prompt_length]
lengths: original prompt length, size: [b]
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs:
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
batch_size = tokens.size(0)
max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1)
if max_prompt_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_prompt_length * batch_size > args.max_tokens_to_oom:
raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
# forward step.
forward_step = ForwardStep(model, batch_size, max_prompt_length)
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_prompt_length - 1)
if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# =============
# Run infernece
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens, position_ids, attention_mask)
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, lengths, output_log_probs
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=True
):
"""Main token generation function.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
Note that top-k = 1 is gready. Also, these paramters are
exclusive meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
prevent_newline_after_colon: if True, it will disable generating new line \n after :
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs: Note that is size is adjusted to a lower value than
max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
if max_sequence_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_sequence_length * batch_size > args.max_tokens_to_oom:
raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
# Lengths of generated seuquence including including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
# Run infernece
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
# Always the last stage should have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
if top_p > 0.0 and top_p_decay > 0.0:
top_p = top_p * top_p_decay
if top_p_bound > 0.0:
top_p = max(top_p, top_p_bound)
# If a prompt length is smaller or equal th current context
# length, it means we have started generating tokens
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
# TODO(rprenger) These stopping methods are tokenizer dependent
# instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
done_token = hit_double_eol | hit_two_eols
elif stop_on_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_eol = (new_sample == 198).byte() & started.byte()
done_token = hit_double_eol | hit_eol
else:
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the length of based on max generated length.
# ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs
def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True):
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
assert(batch_size == 1)
prompt_length = lengths.item()
final_sequence_length = tokens.size(1)
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if prompt_length >= final_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = ForwardStep(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
best_batches = None
done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
scores_size_tensor, tokens_size_tensor = None, None
# =============
# Run infernece
# =============
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
prev_context_length = 0
for context_length in range(prompt_length, final_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
vocab_size = logits.size(2)
log_probs = F.log_softmax(logits, dim=2)
new_scores = log_probs[:, -1, :] + scores
if context_length == prompt_length: # if this is the first one
sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
else:
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
zip(best_words, best_scores, best_beam_ids)
):
if token_id.item() == stop_token:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
tokens[beam_id].clone(),
beam_score,
context_length + 1 - prompt_length
)
else:
# add next predicted token since it is not eos_token
next_beams.append((token_id, beam_score, beam_id))
if len(next_beams) == beam_size:
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# torch.distributed.barrier()
done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
if done:
break
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
tokens)
# set inference key values to make it consistent with best beam index
best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the context length for the next token generation.
prev_context_length = context_length
if mpu.is_pipeline_last_stage():
# if cannot find stop token, add open beams to hyps
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
return tokens, scores
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filteration based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab-size). This will avoid out of vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# Greedy is just simple argmax.
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)
# Top-k or top-p sampling.
else:
# Clone so we do not modify the inputs,
logits = logits.clone()
# Apply temperature in place.
if temperature != 1.0:
logits.div_(temperature)
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(logits, top_k)
elif top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
modify_logits_for_top_p_filtering(logits, top_p)
# After filtering, we need to recalculate the distribution.
probs = logits.softmax(dim=-1)
samples = torch.multinomial(probs, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in
# in the range [0, vocab-size).
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tokenization utilities."""
import torch
from megatron import get_tokenizer, get_args
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
args = get_args()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
if args.tokenizer_type in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer']:
word = tokenizer.decoder[token]
else:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
def tokenize_prompts(prompts=None, tokens_to_generate=None,
add_BOS=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
# On all ranks set to None so we can pass them to functions
sizes_list = None
prompts_tokens_cuda_long_tensor = None
prompts_length_cuda_long_tensor = None
# On the specified rank, build the above.
if torch.distributed.get_rank() == rank:
assert prompts is not None
assert tokens_to_generate is not None
# Tensor of tokens padded and their unpadded length.
prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
_tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
# We need the sizes of these tensors for the boradcast
sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
prompts_tokens_cuda_long_tensor.size(1)] # Sequence lenght
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
# Now that we have the sizes, we can boradcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
prompts_tokens_cuda_long_tensor = broadcast_tensor(
sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
prompts_length_cuda_long_tensor = broadcast_tensor(
sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
rank=rank)
return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
"""Given a set of prompts and number of tokens to generate:
- tokenize prompts
- set the sequence length to be the max of length of prompts
plus the number of tokens we would like to generate
- pad all the sequences to this length so we can convert them
into a 2D tensor.
"""
# Tokenize all the prompts.
tokenizer = get_tokenizer()
if add_BOS:
prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
for prompt in prompts]
else:
prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
# - incorporate the tokens that need to be generated
# - make all the sequences equal length.
# Get the prompts length.
prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
# Get the max prompts length.
max_prompt_len = max(prompts_length)
# Number of tokens in the each sample of the batch.
samples_length = max_prompt_len + tokens_to_generate
# Now update the list of list to be of the same size: samples_length.
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
padding_size = samples_length - prompt_length
prompt_tokens.extend([tokenizer.eod] * padding_size)
# Now we are in a structured format, we can convert to tensors.
prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
return prompts_tokens_tensor, prompts_length_tensor
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
choice = torch.cuda.LongTensor([BEAM_NUM])
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
if not "prompts" in request.get_json():
return "prompts argument required", 400
if "max_len" in request.get_json():
return "max_len is no longer used. Replace with tokens_to_generate", 400
if "sentences" in request.get_json():
return "sentences is no longer used. Replace with prompts", 400
prompts = request.get_json()["prompts"]
if not isinstance(prompts, list):
return "prompts is not a list of strings", 400
if len(prompts) == 0:
return "prompts is empty", 400
if len(prompts) > 128:
return "Maximum number of prompts is 128", 400
tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow
if "tokens_to_generate" in request.get_json():
tokens_to_generate = request.get_json()["tokens_to_generate"]
if not isinstance(tokens_to_generate, int):
return "tokens_to_generate must be an integer greater than 0"
if tokens_to_generate < 0:
return "tokens_to_generate must be an integer greater than or equal to 0"
logprobs = False
if "logprobs" in request.get_json():
logprobs = request.get_json()["logprobs"]
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
if tokens_to_generate == 0 and not logprobs:
return "tokens_to_generate=0 implies logprobs should be True"
temperature = 1.0
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float):
return "temperature must be a positive number less than or equal to 100.0"
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = 0.0
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
if not (0 <= top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = 0.0
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 <= top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
top_p_decay = 0.0
if "top_p_decay" in request.get_json():
top_p_decay = request.get_json()["top_p_decay"]
if not (type(top_p_decay) == float):
return "top_p_decay must be a positive float less than or equal to 1.0"
if top_p == 0.0:
return "top_p_decay cannot be set without top_p"
if not (0 <= top_p_decay <= 1.0):
return "top_p_decay must be less than or equal to 1.0"
top_p_bound = 0.0
if "top_p_bound" in request.get_json():
top_p_bound = request.get_json()["top_p_bound"]
if not (type(top_p_bound) == float):
return "top_p_bound must be a positive float less than or equal to top_p"
if top_p == 0.0:
return "top_p_bound cannot be set without top_p"
if not (0.0 < top_p_bound <= top_p):
return "top_p_bound must be greater than 0 and less than top_p"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
prevent_newline_after_colon = False
if "prevent_newline_after_colon" in request.get_json():
prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"]
if not isinstance(prevent_newline_after_colon, bool):
return "prevent_newline_after_colon must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token=50256
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
try:
if beam_width is not None:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
top_k_sampling=top_k,
top_p_sampling=top_p,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=True,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return ve.args[0]
print("end time: ", datetime.datetime.now())
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
api = Api(self.app)
api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
def run(self, url):
self.app.run(url, threaded=True, debug=False)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron timers."""
from abc import ABC
from abc import abstractmethod
import time
import torch
class TimerBase(ABC):
def __init__(self, name):
self.name = name
@abstractmethod
def start(self, barrier=False):
pass
@abstractmethod
def stop(self, barrier=False):
pass
@abstractmethod
def reset(self):
pass
@abstractmethod
def elapsed(self, reset=True, barrier=False):
pass
class DummyTimer(TimerBase):
def __init__(self):
super().__init__('dummy timer')
def start(self, barrier=False):
return
def stop(self, barrier=False):
return
def reset(self):
return
def elapsed(self, reset=True, barrier=False):
raise Exception('dummy timer should not be used to '
'calculate elapsed time')
class Timer(TimerBase):
"""
Comment on using `barrier`: If this flag is passed, then all
the caller processes will wait till all reach the timing routine.
It is up to the user to make sure all the ranks in `barrier_group`
call it otherwise, it will result in a hang.
Comment on `barrier_group`: By default it is set to None which
in torch distributed land, it will result in the global communicator.
"""
def __init__(self, name):
super().__init__(name)
self._elapsed = 0.0
self._started = False
# Note that None will default to the global process group
self._barrier_group = None
self._start_time = time.time()
def set_barrier_group(self, barrier_group):
self._barrier_group = barrier_group
def start(self, barrier=False):
"""Start the timer."""
assert not self._started, 'timer has already been started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
torch.cuda.synchronize()
self._start_time = time.time()
self._started = True
def stop(self, barrier=False):
"""Stop the timer."""
assert self._started, 'timer is not started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
torch.cuda.synchronize()
self._elapsed += (time.time() - self._start_time)
self._started = False
def reset(self):
"""Reset timer."""
self._elapsed = 0.0
self._started = False
def elapsed(self, reset=True, barrier=False):
"""Calculate the elapsed time."""
_started = self._started
# If the timing in progress, end it first.
if self._started:
self.stop(barrier=barrier)
# Get the elapsed time.
_elapsed = self._elapsed
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if _started:
self.start(barrier=barrier)
return _elapsed
class Timers:
"""Group of timers."""
def __init__(self, log_level, log_option):
self._log_level = log_level
self._log_option = log_option
self._timers = {}
self._log_levels = {}
self._dummy_timer = DummyTimer()
self._max_log_level = 2
def __call__(self, name, log_level=None):
# If the timer has already been set, then check if the log-level
# is provided, it matches the one that the timer was created with.
if name in self._timers:
if log_level is not None:
assert log_level == self._log_levels[name], \
'input log level {} does not match already existing '\
'log level {} for {} timer'.format(
log_level, self._log_levels[name], name)
return self._timers[name]
# If timer does not exist and no log level is provided,
# set it to the max log level which is 2.
if log_level is None:
log_level = self._max_log_level
assert log_level <= self._max_log_level, \
'log level {} is larger than max supported log level {}'.format(
log_level, self._max_log_level)
# Now if the input log level is larger than the one set for
# the timers class, just ignore it and return a dummy timer.
if log_level > self._log_level:
return self._dummy_timer
# Otherwise, initalize the timer and set the level.
self._timers[name] = Timer(name)
self._log_levels[name] = log_level
return self._timers[name]
def _get_elapsed_time_all_ranks(self, names, reset, barrier):
"""
Assumptions:
- All the ranks call this function.
- `names` are identical on all ranks.
If the above assumptions are not met, calling this function will
result in hang.
Arguments:
- names: list of timer names
- reset: reset the timer after recording the elapsed time
- barrier: if set, do a global barrier before time measurments
"""
# First make sure all the callers are in sync.
if barrier:
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
# Here we can use gather on the rank we want to print the
# timing, however, there is no gather_base support in
# pytorch yet. It is simpler to deal with a single tensor
# and since we are only gathering a small amount of data,
# it should be ok to use all-gather instead of gather.
rank_name_to_time = torch.zeros((world_size, len(names)),
dtype=torch.float,
device=torch.cuda.current_device())
for i, name in enumerate(names):
if name in self._timers:
# Here we don't need to pass the barrier flag as all
# the processes are already in sync. This avoids the
# issue of different timers having different barrier
# groups inside their class.
rank_name_to_time[rank, i] = self._timers[name].elapsed(
reset=reset)
# See the note above for why we are not using gather.
torch.distributed._all_gather_base(rank_name_to_time.view(-1),
rank_name_to_time[rank, :].view(-1))
return rank_name_to_time
def _get_global_min_max_time(self, names, reset, barrier, normalizer):
"""Report only min and max times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
barrier)
name_to_min_max_time = {}
for i, name in enumerate(names):
rank_to_time = rank_name_to_time[:, i]
# filter out the ones we did not have any timings for
rank_to_time = rank_to_time[rank_to_time > 0.0]
# If the timer exists:
if rank_to_time.numel() > 0:
name_to_min_max_time[name] = (
rank_to_time.min().item() / normalizer,
rank_to_time.max().item() / normalizer)
return name_to_min_max_time
def _get_global_min_max_time_string(self, names, reset, barrier,
normalizer, max_only):
name_to_min_max_time = self._get_global_min_max_time(
names, reset, barrier, normalizer)
if not name_to_min_max_time:
return None
output_string = '(min, max) time across ranks (ms):'
for name in name_to_min_max_time:
min_time, max_time = name_to_min_max_time[name]
if max_only:
output_string += '\n {}: {:.2f}'.format(
(name+' ').ljust(48, '.'), max_time)
else:
output_string += '\n {}: ({:.2f}, {:.2f})'.format(
(name+' ').ljust(48, '.'), min_time, max_time)
return output_string
def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
"""Report times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
barrier)
output_string = 'times across ranks (ms):'
no_reported_timing = True
for i, name in enumerate(names):
not_yet_found = True
for rank in range(torch.distributed.get_world_size()):
if rank_name_to_time[rank, i] > 0:
no_reported_timing = False
if not_yet_found:
not_yet_found = False
output_string += '\n {}:'.format(name)
output_string += '\n rank {:2d}: {:.2f}'.format(
rank, rank_name_to_time[rank, i] / normalizer)
if no_reported_timing:
return None
return output_string
def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False):
"""Log a group of timers."""
# Print.
assert normalizer > 0.0
if self._log_option in ['max', 'minmax']:
max_only = False
if self._log_option == 'max':
max_only = True
output_string = self._get_global_min_max_time_string(
names, reset, barrier, normalizer/1000.0, max_only)
elif self._log_option == 'all':
output_string = self._get_all_ranks_time_string(names,
reset, barrier,
normalizer/1000.0)
else:
raise Exception('unknown timing log option {}'.format(
self._log_option))
# If no input rank is provided, log on last rank.
if rank is None:
rank = torch.distributed.get_world_size() - 1
if rank == torch.distributed.get_rank() and output_string is not None:
print(output_string, flush=True)
def write(self, names, writer, iteration, normalizer=1.0,
reset=False, barrier=False):
"""Write timers to a tensorboard writer
Note that we only report maximum time across ranks to tensorboard.
"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
name_to_min_max_time = self._get_global_min_max_time(
names, reset, barrier, normalizer)
if writer is not None:
for name in name_to_min_max_time:
_, max_time = name_to_min_max_time[name]
writer.add_scalar(name + '-time', max_time, iteration)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from .tokenizer import build_tokenizer
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding = "utf-8") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
def vocab_size(self):
return len(self.vocab)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
from .file_utils import cached_path
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
"""Initialize tokenizer."""
if args.rank == 0:
print('> building {} tokenizer ...'.format(args.tokenizer_type),
flush=True)
if args.tokenizer_type not in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer']:
assert args.vocab_file is not None
# Select and instantiate the tokenizer.
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True,
vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'BertWordPieceCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=False,
vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == 'SentencePieceTokenizer':
assert args.tokenizer_model is not None
tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
assert args.tokenizer_model is not None
tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
# Add vocab size.
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
args)
return tokenizer
def _vocab_size_with_padding(orig_vocab_size, args):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]',
'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
additional_special_tokens.extend(
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
# self.vocab_size comes from len(vocab)
# and it will increase as we add elements
self.vocab[token] = self.vocab_size
def add_additional_special_tokens(self, tokens_list):
setattr(self, "additional_special_tokens", tokens_list)
for value in tokens_list:
self.add_token(value)
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
@property
def bos_token(self):
""" Beginning of sentence token id """
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
class _GPT2BPETokenizer(AbstractTokenizer):
"""Original GPT2 BPE tokenizer."""
def __init__(self, vocab_file, merge_file):
name = 'GPT2 BPE'
super().__init__(name)
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
special_tokens=[], max_len=None)
self.eod_id = self.tokenizer.encoder['<|endoftext|>']
@property
def vocab_size(self):
return len(self.tokenizer.encoder)
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
class _SentencePieceTokenizer(AbstractTokenizer):
"""SentencePieceTokenizer-Megatron wrapper"""
def __init__(self, model_file, vocab_extra_ids=0):
name = 'SentencePieceTokenizer'
super().__init__(name)
import sentencepiece
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._initalize(vocab_extra_ids)
def _populate_vocab(self):
self._vocab = {}
self._inv_vocab = {}
for i in range(len(self.tokenizer)):
t = self.tokenizer.id_to_piece(i)
self._inv_vocab[i] = t
self._vocab[t] = i
def _initalize(self, vocab_extra_ids):
self._populate_vocab()
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
def _add_special_token(t):
if t not in self._vocab:
next_id = len(self._vocab)
self._vocab[t] = next_id
self._inv_vocab[next_id] = t
self._special_tokens[t] = self._vocab[t]
self._inv_special_tokens[self._vocab[t]] = t
_add_special_token('<CLS>')
self._cls_id = self._vocab['<CLS>']
_add_special_token('<SEP>')
self._sep_id = self._vocab['<SEP>']
_add_special_token('<EOD>')
self._eod_id = self._vocab['<EOD>']
_add_special_token('<MASK>')
self._mask_id = self._vocab['<MASK>']
pad_id = self.tokenizer.pad_id()
try:
pad_token = self.tokenizer.id_to_piece(pad_id)
except IndexError:
pad_token = '<PAD>'
_add_special_token(pad_token)
self._pad_id = self._vocab[pad_token]
bos_id = self.tokenizer.bos_id()
try:
bos_token = self.tokenizer.id_to_piece(bos_id)
except IndexError:
bos_token = '<BOS>'
_add_special_token(bos_token)
self._bos_id = self._vocab[bos_token]
eos_id = self.tokenizer.eos_id()
try:
eos_token = self.tokenizer.id_to_piece(eos_id)
except IndexError:
eos_token = '<EOS>'
_add_special_token(eos_token)
self._eos_id = self._vocab[eos_token]
for i in range(vocab_extra_ids):
t = "<extra_id_{}>".format(i)
_add_special_token(t)
self._t5_tokens += [t]
@property
def vocab_size(self):
return len(self._vocab)
@property
def vocab(self):
return self._vocab
@property
def inv_vocab(self):
return self._inv_vocab
@property
def decoder(self):
return self._inv_vocab
@property
def encoder(self):
return self._vocab
# From:
# https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
def tokenize(self, text):
ids = []
idx = 0
while 1:
indices = {}
for token in self._special_tokens:
try:
indices[token] = text[idx:].index(token)
except ValueError:
continue
if len(indices) == 0:
break
next_token = min(indices, key=indices.get)
next_idx = idx + indices[next_token]
ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
ids.append(self._special_tokens[next_token])
idx = next_idx + len(next_token)
ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
return ids
# From:
# https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
def detokenize(self, ids):
text = ""
last_i = 0
for i, id in enumerate(ids):
if id in self._inv_special_tokens:
text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
text += self._inv_special_tokens[id] + " "
last_i = i + 1
text += self.tokenizer.decode_ids(ids[last_i:])
return text
@property
def cls(self):
return self._cls_id
@property
def sep(self):
return self._sep_id
@property
def pad(self):
return self._pad_id
@property
def bos_token_id(self):
return self._bos_id
@property
def bos(self):
return self._bos_id
@property
def eod(self):
return self._eod_id
@property
def eos_token_id(self):
return self._eos_id
@property
def eos(self):
return self._eos_id
@property
def mask(self):
return self._mask_id
@property
def additional_special_tokens_ids(self):
return [self.vocab[k] for k in self._t5_tokens]
class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
"""SentencePieceTokenizer-Megatron wrapper"""
def __init__(self, model_file,):
super().__init__(model_file, vocab_extra_ids=0)
def _initalize(self, vocab_extra_ids):
self._populate_vocab()
self._pad_id = self.tokenizer.pad_id()
self._bos_id = self.tokenizer.bos_id()
self._eos_id = self.tokenizer.eos_id()
def tokenize(self, text):
return self.tokenizer.encode_as_ids(text)
def detokenize(self, ids):
return self.tokenizer.decode_ids(ids)
@property
def cls(self):
return -1
@property
def sep(self):
return -1
@property
def mask(self):
return -1
@property
def eod(self):
return self._eos_id
@property
def additional_special_tokens_ids(self):
return None
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain utilities."""
from datetime import datetime
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_signal_handler
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
This function will run the followings in the order provided:
1) initialize Megatron.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets.
4) train the modle using the forward_step_func.
Arguments:
train_valid_test_dataset_provider: a function that takes the size of
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
to set already parse arguments.
"""
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')
args = get_args()
timers = get_timers()
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_provider, model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff.
timers('train/valid/test-data-iterators-setup', log_level=0).start(
barrier=True)
if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [
build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
for _ in range(len(model))
]
train_data_iterator = [data_iterators[0]
for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1]
for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2]
for data_iterators in all_data_iterators]
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup',
'train/valid/test-data-iterators-setup'], barrier=True)
print_rank_0('training ...')
iteration = 0
if args.dataloader_type == 'cyclic' and args.retro_add_retriever:
args.train_iters = args.retro_cyclic_train_iters
print_rank_0("retro cyclic train iters : %d" % args.train_iters)
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, process_non_loss_data_func,
True)
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
if args.train_iters:
return
# Constant batch size with sample-based training.
if args.rampup_batch_size is None:
args.train_iters = args.train_samples // args.global_batch_size
else:
# Sample based training with rampup batch size.
iterations = 0
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
iterations += 1
# Reset
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
args.global_batch_size
args.train_iters = iterations
print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.pipeline_model_parallel_split_rank is not None, \
"Split rank needs to be specified for model with both encoder and decoder"
rank = mpu.get_pipeline_model_parallel_rank()
split_rank = args.pipeline_model_parallel_split_rank
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (
rank == (world_size - 1))
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
if not isinstance(model, list):
model = [model]
# Disallow training and inference with Transformer Engine
# for non-GPT models
args.allow_transformer_engine = all([type(m) == GPTModel for m in model])
assert args.allow_transformer_engine or args.transformer_impl == 'local', \
'Transformer Engine is only approved for GPT models'
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if wrap_with_ddp:
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = [torchDDP(model_module, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
for model_module in model]
elif args.DDP_impl == 'local':
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_local_ddp)
for model_module in model]
# broad cast params from data parallel src rank to other data parallel ranks
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
else:
raise NotImplementedError('Unknown DDP implementation specified: '
'{}. Exiting.'.format(args.DDP_impl))
return model
def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
# Iteration-based training.
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
lr_decay_steps = args.lr_decay_samples
wd_incr_steps = args.train_samples
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_samples
else:
raise Exception(
'either train-iters or train-samples should be provided.')
opt_param_scheduler = OptimizerParamScheduler(
optimizer,
max_lr=args.lr,
min_lr=args.min_lr,
lr_warmup_steps=lr_warmup_steps,
lr_decay_steps=lr_decay_steps,
lr_decay_style=args.lr_decay_style,
start_wd=args.start_weight_decay,
end_wd=args.end_weight_decay,
wd_incr_steps=wd_incr_steps,
wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler)
return opt_param_scheduler
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
timers = get_timers()
timers('load-checkpoint', log_level=0).start(barrier=True)
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
timers('load-checkpoint').stop(barrier=True)
timers.log(['load-checkpoint'])
else:
args.iteration = 0
# We only support local DDP with multiple micro-batches.
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
if args.iteration == 0 and len(unwrapped_model) == 1 \
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
print_rank_0("Initializing ICT from pretrained BERT model")
unwrapped_model[0].init_state_dict_from_bert()
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Set grad to zero.
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
for partition in model:
partition.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
timers('forward-backward', log_level=1).start(
barrier=args.barrier_with_L1_time)
forward_backward_func = get_forward_backward_func()
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
grad_scaler=optimizer.scale_loss,
sequence_parallel=args.sequence_parallel,
forward_only=False,
timers=fwd_bwd_timers)
timers('forward-backward').stop()
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Reduce gradients.
optimizer.reduce_model_grads(args, timers)
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
optimizer.gather_model_params(args, timers)
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
return {}, skipped_iter, grad_norm, num_zeros_in_grad
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
# Advanced, skipped, and Nan iterations.
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
# Advanced iterations.
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
# Update losses and set nan iterations
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = [
'forward-backward',
'forward-compute',
'backward-compute',
'batch-generator',
'forward-recv',
'forward-send',
'backward-recv',
'backward-send',
'forward-send-forward-recv',
'forward-send-backward-recv',
'backward-send-forward-recv',
'backward-send-backward-recv',
'forward-backward-send-forward-backward-recv',
'layernorm-grads-all-reduce',
'embedding-grads-all-reduce',
'grads-all-reduce',
'grads-reduce-scatter',
'params-all-gather',
'optimizer-copy-to-main-grad',
'optimizer-unscale-and-check-inf',
'optimizer-clip-main-grad',
'optimizer-count-zeros',
'optimizer-inner-step',
'optimizer-copy-main-to-model-params',
'optimizer']
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# Tensorboard values.
# Timer requires all the ranks to call.
if args.log_timers_to_tensorboard and \
(iteration % args.tensorboard_log_interval == 0):
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if args.log_batch_size_to_tensorboard:
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
writer.add_scalar(key , loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
"mem-reserved-bytes",
mem_stats["reserved_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-bytes",
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
iteration,
)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations
if writer:
if args.log_timers_to_tensorboard:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time_per_iteration * 1000.0)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
log_string += ' global batch size: {:5d} |'.format(batch_size)
for key in total_loss_dict:
if key not in [advanced_iters_key, skipped_iters_key,
nan_iters_key]:
avg = total_loss_dict[key].item() / \
float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[nan_iters_key])
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
# Report memory after optimizer state has been initialized.
report_memory('(after {} iterations)'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
timers('save-checkpoint', log_level=0).start(barrier=True)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
timers('save-checkpoint').stop(barrier=True)
timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function."""
args = get_args()
timers = get_timers()
# Write args to tensorboard
write_args_to_tensorboard()
# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
# Logging.
loss_scale = optimizer.get_loss_scale().item()
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
False)
# Checkpointing
saved_checkpoint = False
if args.exit_signal_handler:
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
print_datetime('exiting program after receiving SIGTERM.')
sys.exit()
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
saved_checkpoint = True
# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.cuda.IntTensor(
[train_time > args.exit_duration_in_mins])
torch.distributed.all_reduce(
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if args.save and not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
return iteration
def evaluate(forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
verbose=False):
"""Evaluation."""
args = get_args()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
compute_feature_bank(model)
# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()
total_loss_dict = {}
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
forward_backward_func = get_forward_backward_func()
loss_dicts = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
sequence_parallel=args.sequence_parallel,
forward_only=True,
timers=None)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for loss_dict in loss_dicts:
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* get_num_microbatches()
collected_non_loss_data = None
if process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True, collect_non_loss_data=True)
# Move model back to the train mode.
for model_module in model:
model_module.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict, collected_non_loss_data
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, process_non_loss_data_func,
verbose=False):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
writer = get_tensorboard_writer()
total_loss_dict, collected_non_loss_data = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer:
writer.add_scalar('{} validation'.format(key),
total_loss_dict[key].item(),
iteration)
writer.add_scalar('{} validation vs samples'.format(key),
total_loss_dict[key].item(),
args.consumed_train_samples)
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar('{} validation ppl'.format(key), ppl,
iteration)
writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples)
if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
print_rank_last('-' * length)
def cyclic_iter(iter):
while True:
for x in iter:
yield x
def build_train_valid_test_data_loaders(
build_train_valid_test_datasets_provider):
"""XXX"""
args = get_args()
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
print_rank_0('> building train, validation, and test datasets ...')
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
assert args.train_samples is None, \
'only backward compatiblity support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
if args.train_samples is None:
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
train_val_test_num_samples)
# Build dataloders.
train_dataloader = build_pretraining_data_loader(
train_ds, args.consumed_train_samples)
valid_dataloader = build_pretraining_data_loader(
valid_ds, args.consumed_valid_samples)
test_dataloader = build_pretraining_data_loader(test_ds, 0)
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and args.train_iters > 0
do_valid = valid_dataloader is not None and args.eval_iters > 0
do_test = test_dataloader is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
flags = torch.cuda.LongTensor(
[int(do_train), int(do_valid), int(do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(flags,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
return train_dataloader, valid_dataloader, test_dataloader
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
args = get_args()
# Build loaders.
train_dataloader, valid_dataloader, test_dataloader = \
build_train_valid_test_data_loaders(
build_train_valid_test_datasets_provider)
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(train_dataloader))
else:
train_data_iterator = None
if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(valid_dataloader))
else:
valid_data_iterator = None
if test_dataloader is not None:
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(test_dataloader))
else:
test_data_iterator = None
return train_data_iterator, valid_data_iterator, test_data_iterator
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""General utilities."""
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import (
get_args,
get_adlr_autoresume,
)
from megatron.core import mpu
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.model.module import param_is_not_shared
def unwrap_model(model, module_instances=(torchDDP)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
return norm_2.item() ** 0.5
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses,
group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return averaged_losses
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | reserved: {}'.format(
torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes)
if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = torch.linalg.norm(param.data)
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
def check_adlr_autoresume_termination(iteration, model,
optimizer, opt_param_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistnecy.
torch.distributed.barrier()
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
autoresume.request_resume()
print_rank_0(">>> training terminated. Returning")
sys.exit(0)
def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, loss_mask, position_ids
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain BERT"""
from functools import partial
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building BERT model ...')
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)
return model
def get_batch(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def loss_func(loss_mask, sentence_order, output_tensor):
lm_loss_, sop_logits = output_tensor
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model.
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask, sentence_order)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for GPT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_data_prefix=args.train_data_path,
valid_data_prefix=args.valid_data_path,
test_data_prefix=args.test_data_path)
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain BERT for Inverse Cloze Task"""
from functools import partial
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider(pre_process=True, post_process=True):
args = get_args()
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def get_group_world_size_rank():
group = mpu.get_data_parallel_group()
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
return group, rank, world_size
class AllgatherFromDataParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
assert input_.dim() == 2
group, rank, world_size = get_group_world_size_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=0).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
group, rank, world_size = get_group_world_size_rank()
assert grad_output.shape[0] % world_size == 0
dim_size = grad_output.shape[0] // world_size
output_list = torch.split(grad_output, dim_size, dim=0)
# get chunk from this rank
output = output_list[rank].contiguous()
return output
def loss_func(output_tensor):
args = get_args()
query_logits, context_logits = output_tensor
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
assert mpu.get_tensor_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
global_batch_size = dist.get_world_size() * micro_batch_size
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# scaling the retriever scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1], sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
for i in range(global_batch_size)]) / global_batch_size])
topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]
labels = torch.arange(global_batch_size).long().cuda()
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
# Scale the retrieval loss
loss = loss * mpu.get_data_parallel_world_size()
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
return loss, stats_dict
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(loss_func)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ICT...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=False,
dataset_type='ict')
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain Retro."""
from functools import partial
import torch
from megatron import get_args, get_retro_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from pretrain_gpt import (
loss_func,
model_provider,
train_valid_test_datasets_provider as standard_datasets_provider,
)
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
retro_args = get_retro_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
if args.retro_add_retriever:
keys += 'neighbor_tokens',
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
if args.retro_add_retriever:
# note: [bs * l * k, r]
# note: 2x == neighbor, continuation
neighbor_tokens = data_b['neighbor_tokens'] \
.view(-1, retro_args.retro_gpt_retrieved_length).long()
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
if args.retro_add_retriever:
_, _, neighbor_position_ids = get_ltor_masks_and_position_ids(
neighbor_tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
neighbor_attention_mask = None
return tokens, labels, loss_mask, attention_mask, position_ids, \
neighbor_tokens, neighbor_attention_mask, neighbor_position_ids
else:
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
if args.retro_add_retriever:
tokens, labels, loss_mask, attention_mask, position_ids, \
neighbor_tokens, neighbor_attention_mask, neighbor_position_ids = \
get_batch(data_iterator)
else:
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
neighbor_tokens, neighbor_attention_mask, neighbor_position_ids = \
None, None, None
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
ret_input_ids=neighbor_tokens,
ret_position_ids=neighbor_position_ids,
ret_attn_mask=neighbor_attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
if args.retro_add_retriever:
return get_retro_datasets()
else:
return standard_datasets_provider(train_val_test_num_samples)
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain T5"""
from functools import partial
import torch
from megatron import (
get_args,
get_timers,
print_rank_0
)
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
"""
Pipeline parallelism for T5
===========================
T5 is a model architecture with both encoder and decoder blocks.
Consequently, pipeline parallelism is implemented slightly differently
compared to architectures like GPT and BERT.
In particular, when pipeline_model_parallel_world_size > 1, each stage
either executes an encoder block or a decoder block. The
--pipeline-model-parallel-split-rank argument controls the rank at which
the split happens: all ranks lower than this argument execute the
encoder block, and all ranks equal to or higher than this argument value
execute the decoder block.
In the encoder section of the model, only one tensor is sent downstream:
the intermediate encoder_hidden_state. In the decoder section of the
model, two tensors are sent downstream in the forward pass: the fully
computed encoder_hidden_state, and the intermediate decoder_hidden_state.
In particular, these are the shapes of the tensors sent between
different workers:
If rank is in decoder section:
intermediate decoder_hidden_state (pre-transpose),
complete encoder_hidden_state (post-transpose).
If rank is at boundary between encoder and decoder sections:
complete encoder_hidden_state (post-transpose).
If rank is in encoder section:
intermediate encoder_hidden_state (pre-transpose).
Additionally, we have code in the backward_step function in schedules.py
to accumulate the encoder_hidden_state gradient across skip connections
(encoder_hidden_state fed in as input to each layer in the decoder).
"""
def model_provider(pre_process=True, post_process=True,
add_encoder=True, add_decoder=True):
"""Build the model."""
print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
return model
def get_batch(data_iterator):
"""Build the batch."""
keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
'enc_mask', 'dec_mask', 'enc_dec_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens_enc = data_b['text_enc'].long()
tokens_dec = data_b['text_dec'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()
enc_mask = (data_b['enc_mask'] < 0.5)
dec_mask = (data_b['dec_mask'] < 0.5)
enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
return tokens_enc, tokens_dec, loss_mask, labels, \
enc_mask, dec_mask, enc_dec_mask
def loss_func(loss_mask, output_tensor):
lm_loss_ = output_tensor.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator', log_level=2).start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator)
timers('batch generator').stop()
# Forward model lm_labels
output_tensor = model(tokens_enc,
tokens_dec,
enc_mask,
dec_mask,
enc_dec_mask,
tokentype_ids=None,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for T5 ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.encoder_seq_length,
max_seq_length_dec=args.decoder_seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
dataset_type='t5')
print_rank_0("> finished creating T5 datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
print_rank_0("building VIT model ...")
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
print_rank_0("building MIT model ...")
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels)
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
accuracy = torch.mean(correct)
averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator", log_level=2).start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
# Forward model. lm_labels
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment