Commit 1016e98a authored by zhuww's avatar zhuww
Browse files

megatron-lm0.3.2 based on dtk-22.10

parent 6c3f6c7b
<!-- coding=utf-8-->
<!-- Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.-->
<!---->
<!-- Licensed under the Apache License, Version 2.0 (the "License");-->
<!-- you may not use this file except in compliance with the License.-->
<!-- You may obtain a copy of the License at-->
<!---->
<!-- http://www.apache.org/licenses/LICENSE-2.0-->
<!---->
<!-- Unless required by applicable law or agreed to in writing, software-->
<!-- distributed under the License is distributed on an "AS IS" BASIS,-->
<!-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->
<!-- See the License for the specific language governing permissions and-->
<!-- limitations under the License.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
&:focus {
outline: none;
border: 1px solid darken(#ddd, 5%);
box-shadow: inset 0 0 0.5rem darken(#ddd, 5%);
}
}
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt"autofocus></textarea>
<label for="tokens_to_generate">Number tokens to generate (1-1024):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256", value=32>
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .api import (
generate,
generate_and_post_process)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference API."""
import torch
from megatron import mpu
from .communication import broadcast_float_list
from .generation import (
generate_tokens_probs_and_return_on_first_stage,
score_and_return_on_first_stage)
from .tokenization import (
tokenize_prompts,
detokenize_generations)
def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
random_seed=-1):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs = generate(
model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
tokens, prompts_plus_generations, prompts_plus_generations_segments = \
detokenize_generations(tokens, lengths, True)
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
output_log_probs[i] = prob[:len(seg)-1]
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
return None
def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
random_seed=-1):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
"""
# Make sure input params are avaialble to all ranks.
values = [tokens_to_generate,
return_output_log_probs,
top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination,
stop_on_double_eol,
stop_on_eol,
random_seed]
values_float_tensor = broadcast_float_list(10, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
stop_on_double_eol = bool(values_float_tensor[7].item())
stop_on_eol = bool(values_float_tensor[8].item())
random_seed = int(values_float_tensor[9].item())
if random_seed != -1:
torch.random.manual_seed(random_seed)
# Tokenize prompts and get the batch.
# Note that these tensors are broadcaseted to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
if tokens_to_generate == 0:
return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
top_k=top_k_sampling,
top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Communications utilities."""
import torch
from megatron import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer inplace."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron import (
get_args,
mpu)
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_len):
"""Note that offsets are set to zero and we always set the
flag to allocate memory. After the first call, make sure to
set this flag to False."""
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_len):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params)
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
return output_tensor
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Update the sequence length offset.
inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
# Slice among the batch dimenion.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = _forward_step_helper(model, tokens2use, position_ids2use,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
inference_params.batch_size_offset = 0
return logits
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max_prompt_length]
lengths: original prompt length, size: [b]
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs:
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
batch_size = tokens.size(0)
max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1)
max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# =============
# Run infernece
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens, position_ids, attention_mask)
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, lengths, output_log_probs
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False
):
"""Main token generation function.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
Note that top-k = 1 is gready. Also, these paramters are
exclusive meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs: Note that is size is adjusted to a lower value than
max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if min_prompt_length >= max_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
# Lengths of generated seuquence including including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
# Run infernece
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# If a prompt length is smaller or equal th current context
# length, it means we have started generating tokens
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
# TODO(rprenger) These stopping methods are tokenizer dependent
# instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
done_token = hit_double_eol | hit_two_eols
elif stop_on_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_eol = (new_sample == 198).byte() & started.byte()
done_token = hit_double_eol | hit_eol
else:
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the length of based on max generated length.
# ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filteration based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab-size). This will avoid out of vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# Greedy is just simple argmax.
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)
# Top-k or top-p sampling.
else:
# Clone so we do not modify the inputs,
logits = logits.clone()
# Apply temperature in place.
if temperature != 1.0:
logits.div_(temperature)
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(logits, top_k)
elif top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
modify_logits_for_top_p_filtering(logits, top_p)
# After filtering, we need to recalculate the distribution.
probs = logits.softmax(dim=-1)
samples = torch.multinomial(probs, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in
# in the range [0, vocab-size).
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization utilities."""
import torch
from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
def tokenize_prompts(prompts=None, tokens_to_generate=None,
add_BOS=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
# On all ranks set to None so we can pass them to functions
sizes_list = None
prompts_tokens_cuda_long_tensor = None
prompts_length_cuda_long_tensor = None
# On the specified rank, build the above.
if torch.distributed.get_rank() == rank:
assert prompts is not None
assert tokens_to_generate is not None
# Tensor of tokens padded and their unpadded length.
prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
_tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
# We need the sizes of these tensors for the boradcast
sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
prompts_tokens_cuda_long_tensor.size(1)] # Sequence lenght
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
# Now that we have the sizes, we can boradcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
prompts_tokens_cuda_long_tensor = broadcast_tensor(
sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
prompts_length_cuda_long_tensor = broadcast_tensor(
sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
rank=rank)
return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
"""Given a set of prompts and number of tokens to generate:
- tokenize prompts
- set the sequence length to be the max of length of prompts
plus the number of tokens we would like to generate
- pad all the sequences to this length so we can convert them
into a 2D tensor.
"""
# Tokenize all the prompts.
tokenizer = get_tokenizer()
if add_BOS:
prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
for prompt in prompts]
else:
prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
# - incorporate the tokens that need to be generated
# - make all the sequences equal length.
# Get the prompts length.
prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
# Get the max prompts length.
max_prompt_len = max(prompts_length)
# Number of tokens in the each sample of the batch.
samples_length = max_prompt_len + tokens_to_generate
# Now update the list of list to be of the same size: samples_length.
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
padding_size = samples_length - prompt_length
prompt_tokens.extend([tokenizer.eod] * padding_size)
# Now we are in a structured format, we can convert to tensors.
prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
return prompts_tokens_tensor, prompts_length_tensor
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
GENERATE_NUM = 0
lock = threading.Lock()
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
if not "prompts" in request.get_json():
return "prompts argument required", 400
if "max_len" in request.get_json():
return "max_len is no longer used. Replace with tokens_to_generate", 400
if "sentences" in request.get_json():
return "sentences is no longer used. Replace with prompts", 400
prompts = request.get_json()["prompts"]
if len(prompts) > 128:
return "Maximum number of prompts is 128", 400
tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow
if "tokens_to_generate" in request.get_json():
tokens_to_generate = request.get_json()["tokens_to_generate"]
if not isinstance(tokens_to_generate, int):
return "tokens_to_generate must be an integer greater than 0"
if tokens_to_generate < 0:
return "tokens_to_generate must be an integer greater than or equal to 0"
logprobs = False
if "logprobs" in request.get_json():
logprobs = request.get_json()["logprobs"]
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
if tokens_to_generate == 0 and not logprobs:
return "tokens_to_generate=0 implies logprobs should be True"
temperature = 1.0
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float):
return "temperature must be a positive number less than or equal to 100.0"
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = 0.0
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
if not (0 <= top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = 0.0
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 <= top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
try:
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
top_k_sampling=top_k,
top_p_sampling=top_p,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=True,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
api = Api(self.app)
api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
def run(self, url):
self.app.run(url, threaded=True, debug=False)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import build_tokenizer
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
def vocab_size(self):
return len(self.vocab)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
from .file_utils import cached_path
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
"""Initialize tokenizer."""
if args.rank == 0:
print('> building {} tokenizer ...'.format(args.tokenizer_type),
flush=True)
# Select and instantiate the tokenizer.
assert args.vocab_file is not None
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True,
vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'BertWordPieceCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=False,
vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
# Add vocab size.
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
args)
return tokenizer
def _vocab_size_with_padding(orig_vocab_size, args):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]',
'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
additional_special_tokens.extend(
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
# self.vocab_size comes from len(vocab)
# and it will increase as we add elements
self.vocab[token] = self.vocab_size
def add_additional_special_tokens(self, tokens_list):
setattr(self, "additional_special_tokens", tokens_list)
for value in tokens_list:
self.add_token(value)
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
@property
def bos_token(self):
""" Beginning of sentence token id """
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
class _GPT2BPETokenizer(AbstractTokenizer):
"""Original GPT2 BPE tokenizer."""
def __init__(self, vocab_file, merge_file):
name = 'GPT2 BPE'
super().__init__(name)
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
special_tokens=[], max_len=None)
self.eod_id = self.tokenizer.encoder['<|endoftext|>']
@property
def vocab_size(self):
return len(self.tokenizer.encoder)
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""
from datetime import datetime
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron import get_signal_handler
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
This function will run the followings in the order provided:
1) initialize Megatron.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets.
4) train the modle using the forward_step_func.
Arguments:
train_valid_test_dataset_provider: a function that takes the size of
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
to set already parse arguments.
"""
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')
args = get_args()
timers = get_timers()
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff.
timers('train/valid/test-data-iterators-setup').start()
if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
for _ in range(len(model))
]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
print_rank_0('training ...')
iteration = 0
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, process_non_loss_data_func,
True)
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
if args.train_iters:
return
# Constant batch size with sample-based training.
if args.rampup_batch_size is None:
args.train_iters = args.train_samples // args.global_batch_size
else:
# Sample based training with rampup batch size.
iterations = 0
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
iterations += 1
# Reset
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
args.global_batch_size
args.train_iters = iterations
print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.pipeline_model_parallel_split_rank is not None, \
"Split rank needs to be specified for model with both encoder and decoder"
rank = mpu.get_pipeline_model_parallel_rank()
split_rank = args.pipeline_model_parallel_split_rank
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (
rank == (world_size - 1))
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if wrap_with_ddp:
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = [torchDDP(model_module, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
for model_module in model]
elif args.DDP_impl == 'local':
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_local_ddp)
for model_module in model]
# broad cast params from data parallel src rank to other data parallel ranks
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
else:
raise NotImplementedError('Unknown DDP implementation specified: '
'{}. Exiting.'.format(args.DDP_impl))
return model
def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
# Iteration-based training.
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
lr_decay_steps = args.lr_decay_samples
wd_incr_steps = args.train_samples
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_samples
else:
raise Exception(
'either train-iters or train-samples should be provided.')
opt_param_scheduler = OptimizerParamScheduler(
optimizer,
max_lr=args.lr,
min_lr=args.min_lr,
lr_warmup_steps=lr_warmup_steps,
lr_decay_steps=lr_decay_steps,
lr_decay_style=args.lr_decay_style,
start_wd=args.start_weight_decay,
end_wd=args.end_weight_decay,
wd_incr_steps=wd_incr_steps,
wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler)
return opt_param_scheduler
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
timers = get_timers()
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
else:
args.iteration = 0
# We only support local DDP with multiple micro-batches.
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
if args.iteration == 0 and len(unwrapped_model) == 1 \
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
print_rank_0("Initializing ICT from pretrained BERT model")
unwrapped_model[0].init_state_dict_from_bert()
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Set grad to zero.
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
for partition in model:
partition.zero_grad_buffer()
optimizer.zero_grad()
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if mpu.get_tensor_model_parallel_world_size() > 1 and \
args.sequence_parallel:
grads = []
for model_module in model:
unwrapped_model = unwrap_model(
model_module, (torchDDP, LocalDDP, Float16Module))
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel', False):
grad = param.main_grad if args.DDP_impl == 'local' else param.grad
grads.append(grad.data)
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=mpu.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
# Empty unused memory
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
return {}, skipped_iter, grad_norm, num_zeros_in_grad
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
# Advanced, skipped, and Nan iterations.
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
# Advanced iterations.
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
# Update losses and set nan iterations
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = []
def add_to_logging(name):
if name in timers.timers:
timers_to_log.append(name)
add_to_logging('forward-compute')
add_to_logging('forward-recv')
add_to_logging('forward-send')
add_to_logging('forward-backward-send-forward-backward-recv')
add_to_logging('backward-compute')
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# Tensorboard values.
if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
is_last_rank():
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if args.log_batch_size_to_tensorboard:
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
writer.add_scalar(key , loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
"mem-reserved-bytes",
mem_stats["reserved_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-bytes",
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
iteration,
)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed()
elapsed_time_per_iteration = elapsed_time / total_iterations
if writer:
if args.log_timers_to_tensorboard:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time_per_iteration * 1000.0)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
log_string += ' global batch size: {:5d} |'.format(batch_size)
for key in total_loss_dict:
if key not in [advanced_iters_key, skipped_iters_key,
nan_iters_key]:
avg = total_loss_dict[key].item() / \
float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[nan_iters_key])
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
# Report memory after optimizer state has been initialized.
report_memory('(after {} iterations)'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('save-checkpoint').stop()
timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function."""
args = get_args()
timers = get_timers()
# Write args to tensorboard
write_args_to_tensorboard()
# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
timers('interval-time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
# Logging.
loss_scale = optimizer.get_loss_scale().item()
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
False)
# Checkpointing
saved_checkpoint = False
if args.exit_signal_handler:
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
print_datetime('exiting program after receiving SIGTERM.')
sys.exit()
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
saved_checkpoint = True
# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.cuda.IntTensor(
[train_time > args.exit_duration_in_mins])
torch.distributed.all_reduce(
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
return iteration
def evaluate(forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
verbose=False):
"""Evaluation."""
args = get_args()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
compute_feature_bank(model)
# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()
total_loss_dict = {}
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
forward_backward_func = get_forward_backward_func()
loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for loss_dict in loss_dicts:
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* get_num_microbatches()
collected_non_loss_data = None
if process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True, collect_non_loss_data=True)
# Move model back to the train mode.
for model_module in model:
model_module.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict, collected_non_loss_data
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, process_non_loss_data_func,
verbose=False):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
writer = get_tensorboard_writer()
total_loss_dict, collected_non_loss_data = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer:
writer.add_scalar('{} validation'.format(key),
total_loss_dict[key].item(),
iteration)
writer.add_scalar('{} validation vs samples'.format(key),
total_loss_dict[key].item(),
args.consumed_train_samples)
if args.log_validation_ppl_to_tensorboard:
writer.add_scalar('{} validation ppl'.format(key), ppl,
iteration)
writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples)
if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
print_rank_last('-' * length)
def cyclic_iter(iter):
while True:
for x in iter:
yield x
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
"""XXX"""
args = get_args()
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
print_rank_0('> building train, validation, and test datasets ...')
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
assert args.train_samples is None, \
'only backward compatiblity support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
if args.train_samples is None:
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
train_val_test_num_samples)
# Build dataloders.
train_dataloader = build_pretraining_data_loader(
train_ds, args.consumed_train_samples)
valid_dataloader = build_pretraining_data_loader(
valid_ds, args.consumed_valid_samples)
test_dataloader = build_pretraining_data_loader(test_ds, 0)
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and args.train_iters > 0
do_valid = valid_dataloader is not None and args.eval_iters > 0
do_test = test_dataloader is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
flags = torch.cuda.LongTensor(
[int(do_train), int(do_valid), int(do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(flags,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(train_dataloader))
else:
train_data_iterator = None
if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(valid_dataloader))
else:
valid_data_iterator = None
if test_dataloader is not None:
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(test_dataloader))
else:
test_data_iterator = None
return train_data_iterator, valid_data_iterator, test_data_iterator
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities."""
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def unwrap_model(model, module_instances=(torchDDP)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
return norm_2.item() ** 0.5
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses,
group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return averaged_losses
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | reserved: {}'.format(
torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes)
if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = torch.linalg.norm(param.data)
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
def check_adlr_autoresume_termination(iteration, model,
optimizer, opt_param_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistnecy.
torch.distributed.barrier()
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
autoresume.request_resume()
print_rank_0(">>> training terminated. Returning")
sys.exit(0)
def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, loss_mask, position_ids
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT"""
from functools import partial
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building BERT model ...')
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)
return model
def get_batch(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def loss_func(loss_mask, sentence_order, output_tensor):
lm_loss_, sop_logits = output_tensor
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model.
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask, sentence_order)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for GPT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
from functools import partial
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider(pre_process=True, post_process=True):
args = get_args()
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def get_group_world_size_rank():
group = mpu.get_data_parallel_group()
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
return group, rank, world_size
class AllgatherFromDataParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
assert input_.dim() == 2
group, rank, world_size = get_group_world_size_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=0).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
group, rank, world_size = get_group_world_size_rank()
assert grad_output.shape[0] % world_size == 0
dim_size = grad_output.shape[0] // world_size
output_list = torch.split(grad_output, dim_size, dim=0)
# get chunk from this rank
output = output_list[rank].contiguous()
return output
def loss_func(output_tensor):
args = get_args()
query_logits, context_logits = output_tensor
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
assert mpu.get_tensor_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
global_batch_size = dist.get_world_size() * micro_batch_size
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# scaling the retriever scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1], sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
for i in range(global_batch_size)]) / global_batch_size])
topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]
labels = torch.arange(global_batch_size).long().cuda()
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
# Scale the retrieval loss
loss = loss * mpu.get_data_parallel_world_size()
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
return loss, stats_dict
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(loss_func)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ICT...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=False,
dataset_type='ict')
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain T5"""
from functools import partial
import torch
from megatron import (
get_args,
get_timers,
mpu,
print_rank_0
)
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
"""
Pipeline parallelism for T5
===========================
T5 is a model architecture with both encoder and decoder blocks.
Consequently, pipeline parallelism is implemented slightly differently
compared to architectures like GPT and BERT.
In particular, when pipeline_model_parallel_world_size > 1, each stage
either executes an encoder block or a decoder block. The
--pipeline-model-parallel-split-rank argument controls the rank at which
the split happens: all ranks lower than this argument execute the
encoder block, and all ranks equal to or higher than this argument value
execute the decoder block.
In the encoder section of the model, only one tensor is sent downstream:
the intermediate encoder_hidden_state. In the decoder section of the
model, two tensors are sent downstream in the forward pass: the fully
computed encoder_hidden_state, and the intermediate decoder_hidden_state.
In particular, these are the shapes of the tensors sent between
different workers:
If rank is in decoder section:
intermediate decoder_hidden_state (pre-transpose),
complete encoder_hidden_state (post-transpose).
If rank is at boundary between encoder and decoder sections:
complete encoder_hidden_state (post-transpose).
If rank is in encoder section:
intermediate encoder_hidden_state (pre-transpose).
Additionally, we have code in the backward_step function in schedules.py
to accumulate the encoder_hidden_state gradient across skip connections
(encoder_hidden_state fed in as input to each layer in the decoder).
"""
def model_provider(pre_process=True, post_process=True,
add_encoder=True, add_decoder=True):
"""Build the model."""
print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
return model
def get_batch(data_iterator):
"""Build the batch."""
keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
'enc_mask', 'dec_mask', 'enc_dec_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_enc = data_b['text_enc'].long()
tokens_dec = data_b['text_dec'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()
enc_mask = (data_b['enc_mask'] < 0.5)
dec_mask = (data_b['dec_mask'] < 0.5)
enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
return tokens_enc, tokens_dec, loss_mask, labels, \
enc_mask, dec_mask, enc_dec_mask
def loss_func(loss_mask, output_tensor):
lm_loss_ = output_tensor.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator)
timers('batch generator').stop()
# Forward model lm_labels
output_tensor = model(tokens_enc,
tokens_dec,
enc_mask,
dec_mask,
enc_dec_mask,
tokentype_ids=None,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for T5 ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.encoder_seq_length,
max_seq_length_dec=args.decoder_seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
dataset_type='t5')
print_rank_0("> finished creating T5 datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
print_rank_0("building VIT model ...")
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
print_rank_0("building MIT model ...")
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels)
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
accuracy = torch.mean(correct)
averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
# Forward model. lm_labels
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment