Commit 523ec9cc authored by wangsen's avatar wangsen
Browse files

all

parents
Pipeline #1668 failed with stages
in 0 seconds
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron.training import get_args, get_tokenizer
from megatron.core import mpu
from megatron.training.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
from .beam_utils import BeamHypotheses
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
Args:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max_prompt_length]
lengths: original prompt length, size: [b]
Note: Outside of model, other parameters only need to be available on
rank 0.
Returns:
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
batch_size = tokens.size(0)
max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1)
if max_prompt_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_prompt_length * batch_size > args.max_tokens_to_oom:
raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
# forward step.
forward_step = ForwardStep(model, batch_size, max_prompt_length)
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_prompt_length - 1)
if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# =============
# Run infernece
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens, position_ids, attention_mask)
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, lengths, output_log_probs, logits
def generate_tokens_probs_and_return_on_first_stage(
model, forward_step, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=True
):
"""Main token generation function.
Args:
model: no interleaving is supported.
forward_step (ForwardStep): Class for running the model forward step.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
Note that top-k = 1 is gready. Also, these paramters are
exclusive meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
prevent_newline_after_colon: if True, it will disable generating new line \n after :
Note: Outside of model, other parameters only need to be available on
rank 0.
Returns: Note that is size is adjusted to a lower value than
max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
if max_sequence_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_sequence_length * batch_size > args.max_tokens_to_oom:
raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
# forward step.
forward_step = forward_step(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
# Lengths of generated seuquence including including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
# Run infernece
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
# Always the last stage should have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
if top_p > 0.0 and top_p_decay > 0.0:
top_p = top_p * top_p_decay
if top_p_bound > 0.0:
top_p = max(top_p, top_p_bound)
# If a prompt length is smaller or equal th current context
# length, it means we have started generating tokens
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
# TODO(rprenger) These stopping methods are tokenizer dependent
# instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
done_token = hit_double_eol | hit_two_eols
elif stop_on_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_eol = (new_sample == 198).byte() & started.byte()
done_token = hit_double_eol | hit_eol
else:
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the length of based on max generated length.
# ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs, None
def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True):
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
assert(batch_size == 1)
prompt_length = lengths.item()
final_sequence_length = tokens.size(1)
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if prompt_length >= final_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = forward_step(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
best_batches = None
done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
scores_size_tensor, tokens_size_tensor = None, None
# =============
# Run infernece
# =============
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
prev_context_length = 0
for context_length in range(prompt_length, final_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
vocab_size = logits.size(2)
log_probs = F.log_softmax(logits, dim=2)
new_scores = log_probs[:, -1, :] + scores
if context_length == prompt_length: # if this is the first one
sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
else:
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
zip(best_words, best_scores, best_beam_ids)
):
if token_id.item() == stop_token:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
tokens[beam_id].clone(),
beam_score,
context_length + 1 - prompt_length
)
else:
# add next predicted token since it is not eos_token
next_beams.append((token_id, beam_score, beam_id))
if len(next_beams) == beam_size:
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# torch.distributed.barrier()
done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
if done:
break
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
tokens)
# set inference key values to make it consistent with best beam index
best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the context length for the next token generation.
prev_context_length = context_length
if mpu.is_pipeline_last_stage():
# if cannot find stop token, add open beams to hyps
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
return tokens, scores
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filteration based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab-size). This will avoid out of vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# Greedy is just simple argmax.
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)
# Top-k or top-p sampling.
else:
# Clone so we do not modify the inputs,
logits = logits.clone()
# Apply temperature in place.
if temperature != 1.0:
logits.div_(temperature)
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(logits, top_k)
elif top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
modify_logits_for_top_p_filtering(logits, top_p)
# After filtering, we need to recalculate the distribution.
probs = logits.softmax(dim=-1)
samples = torch.multinomial(probs, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in
# in the range [0, vocab-size).
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tokenization utilities."""
import torch
from megatron.training import get_tokenizer, get_args
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
args = get_args()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
if args.tokenizer_type in ['SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'Llama2Tokenizer',
'MistralTokenizer']:
word = tokenizer.decoder[token]
elif args.tokenizer_type == 'Llama3Tokenizer':
word = tokenizer.decode([token])
elif args.tokenizer_type == 'NullTokenizer':
word = str(token)
else:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
def tokenize_prompts(prompts=None, tokens_to_generate=None,
add_BOS=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
# On all ranks set to None so we can pass them to functions
sizes_list = None
prompts_tokens_cuda_long_tensor = None
prompts_length_cuda_long_tensor = None
# On the specified rank, build the above.
if torch.distributed.get_rank() == rank:
assert prompts is not None
assert tokens_to_generate is not None
# Tensor of tokens padded and their unpadded length.
prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
_tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
# We need the sizes of these tensors for the boradcast
sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
prompts_tokens_cuda_long_tensor.size(1)] # Sequence lenght
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
# Now that we have the sizes, we can boradcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
prompts_tokens_cuda_long_tensor = broadcast_tensor(
sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
prompts_length_cuda_long_tensor = broadcast_tensor(
sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
rank=rank)
return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
"""Given a set of prompts and number of tokens to generate:
- tokenize prompts
- set the sequence length to be the max of length of prompts
plus the number of tokens we would like to generate
- pad all the sequences to this length so we can convert them
into a 2D tensor.
"""
# Tokenize all the prompts.
tokenizer = get_tokenizer()
if add_BOS:
prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
for prompt in prompts]
else:
prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
# - incorporate the tokens that need to be generated
# - make all the sequences equal length.
# Get the prompts length.
prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
# Get the max prompts length.
max_prompt_len = max(prompts_length)
# Number of tokens in the each sample of the batch.
samples_length = max_prompt_len + tokens_to_generate
# Now update the list of list to be of the same size: samples_length.
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
padding_size = samples_length - prompt_length
prompt_tokens.extend([tokenizer.eod] * padding_size)
# Now we are in a structured format, we can convert to tensors.
prompts_tokens_tensor = torch.tensor(prompts_tokens, dtype=torch.long, device='cuda')
prompts_length_tensor = torch.tensor(prompts_length, dtype=torch.long, device='cuda')
return prompts_tokens_tensor, prompts_length_tensor
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron.training import get_args
from megatron.inference.text_generation import generate_and_post_process
from megatron.inference.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.tensor([GENERATE_NUM], dtype=torch.long, device='cuda')
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
choice = torch.tensor([BEAM_NUM], dtype=torch.long, device='cuda')
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
if not "prompts" in request.get_json():
return "prompts argument required", 400
if "max_len" in request.get_json():
return "max_len is no longer used. Replace with tokens_to_generate", 400
if "sentences" in request.get_json():
return "sentences is no longer used. Replace with prompts", 400
prompts = request.get_json()["prompts"]
if not isinstance(prompts, list):
return "prompts is not a list of strings", 400
if len(prompts) == 0:
return "prompts is empty", 400
if len(prompts) > 128:
return "Maximum number of prompts is 128", 400
tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow
if "tokens_to_generate" in request.get_json():
tokens_to_generate = request.get_json()["tokens_to_generate"]
if not isinstance(tokens_to_generate, int):
return "tokens_to_generate must be an integer greater than 0"
if tokens_to_generate < 0:
return "tokens_to_generate must be an integer greater than or equal to 0"
logprobs = False
if "logprobs" in request.get_json():
logprobs = request.get_json()["logprobs"]
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
if tokens_to_generate == 0 and not logprobs:
return "tokens_to_generate=0 implies logprobs should be True"
temperature = 1.0
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float):
return "temperature must be a positive number less than or equal to 100.0"
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = 0.0
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
if not (0 <= top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = 0.0
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 <= top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
top_p_decay = 0.0
if "top_p_decay" in request.get_json():
top_p_decay = request.get_json()["top_p_decay"]
if not (type(top_p_decay) == float):
return "top_p_decay must be a positive float less than or equal to 1.0"
if top_p == 0.0:
return "top_p_decay cannot be set without top_p"
if not (0 <= top_p_decay <= 1.0):
return "top_p_decay must be less than or equal to 1.0"
top_p_bound = 0.0
if "top_p_bound" in request.get_json():
top_p_bound = request.get_json()["top_p_bound"]
if not (type(top_p_bound) == float):
return "top_p_bound must be a positive float less than or equal to top_p"
if top_p == 0.0:
return "top_p_bound cannot be set without top_p"
if not (0.0 < top_p_bound <= top_p):
return "top_p_bound must be greater than 0 and less than top_p"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
prevent_newline_after_colon = False
if "prevent_newline_after_colon" in request.get_json():
prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"]
if not isinstance(prevent_newline_after_colon, bool):
return "prevent_newline_after_colon must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token=50256
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
try:
if beam_width is not None:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
top_k_sampling=top_k,
top_p_sampling=top_p,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=True,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return ve.args[0]
print("end time: ", datetime.datetime.now())
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
api = Api(self.app)
api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
def run(self, url, port):
self.app.run(url, threaded=True, debug=False, port=port)
"""AutoAugment data augmentation policy for ImageNet.
-- Begin license text.
MIT License
Copyright (c) 2018 Philip Popien
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-- End license text.
Code adapted from https://github.com/DeepVoltaire/AutoAugment.
This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
policies.
Reference:
[1] https://arxiv.org/abs/1805.09501
"""
import random
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps
_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
class ImageNetPolicy:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
Usage example as a Pytorch Transform:
>>> transform=transforms.Compose([transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self.policies = [
SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
]
def __call__(self, img):
"""Define call method for ImageNetPolicy class."""
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
"""Define repr method for ImageNetPolicy class."""
return "ImageNetPolicy"
class SubPolicy:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def __init__(
self,
operation1,
probability1,
magnitude_idx1,
operation2,
probability2,
magnitude_idx2,
fillcolor,
):
"""Initialize a SubPolicy.
Args:
operation1 (str): Key specifying the first augmentation operation.
There are fourteen key values altogether (see supported_ops below
listing supported operations). probability1 (float): Probability
within [0., 1.] of applying the first augmentation operation.
magnitude_idx1 (int): Integer specifiying the strength of the first
operation as an index further used to derive the magnitude from a
range of possible values.
operation2 (str): Key specifying the second augmentation operation.
probability2 (float): Probability within [0., 1.] of applying the
second augmentation operation.
magnitude_idx2 (int): Integer specifiying the strength of the
second operation as an index further used to derive the magnitude
from a range of possible values.
fillcolor (tuple): RGB color components of the color to be used for
filling.
Returns:
"""
# List of supported operations for operation1 and operation2.
supported_ops = [
"shearX",
"shearY",
"translateX",
"translateY",
"rotate",
"color",
"posterize",
"solarize",
"contrast",
"sharpness",
"brightness",
"autocontrast",
"equalize",
"invert",
]
assert (operation1 in supported_ops) and (
operation2 in supported_ops
), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
assert (
0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert (
isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert (
isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels = _MAX_LEVEL + 1
ranges = {
"shearX": np.linspace(0, 0.3, num_levels),
"shearY": np.linspace(0, 0.3, num_levels),
"translateX": np.linspace(0, 150 / 331, num_levels),
"translateY": np.linspace(0, 150 / 331, num_levels),
"rotate": np.linspace(0, 30, num_levels),
"color": np.linspace(0.0, 0.9, num_levels),
"posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
np.int32
),
"solarize": np.linspace(256, 0, num_levels), # range [0, 256]
"contrast": np.linspace(0.0, 0.9, num_levels),
"sharpness": np.linspace(0.0, 0.9, num_levels),
"brightness": np.linspace(0.0, 0.9, num_levels),
"autocontrast": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"equalize": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"invert": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
}
def rotate_with_fill(img, magnitude):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated = img.convert("RGBA").rotate(magnitude)
rotated_filled = Image.composite(
rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
)
return rotated_filled.convert(img.mode)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
magnitude * img.size[0] * random.choice([-1, 1]),
0,
1,
0,
),
fillcolor=fillcolor,
),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
0,
0,
1,
magnitude * img.size[1] * random.choice([-1, 1]),
),
fillcolor=fillcolor,
),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * random.choice([-1, 1])
),
"posterize": lambda img, magnitude: ImageOps.posterize(
img, magnitude
),
"solarize": lambda img, magnitude: ImageOps.solarize(
img, magnitude
),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self.probability1 = probability1
self.operation1 = func_dict[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self.probability2 = probability2
self.operation2 = func_dict[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if random.random() < self.probability1:
img = self.operation1(img, self.magnitude1)
# Randomly apply operation 2.
if random.random() < self.probability2:
img = self.operation2(img, self.magnitude2)
return img
import os
import time
import numpy as np
import torch
from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.legacy.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.legacy.data.data_samplers import MegatronPretrainingSampler
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
# Use megatron's sampler with consumed samples set to 0 as
# this is only for evaluation and don't intend to resume half way.
# Also, set the drop last to false as don't intend to remove
# the last batch
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=0,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
drop_last=False)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_mask',
'context_tokens', 'context_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_mask = data_b['query_mask'] < 0.5
context_tokens = data_b['context_tokens'].long()
context_mask = data_b['context_mask'] < 0.5
block_indices = data_b['block_data'].long()
return query_tokens, query_mask,\
context_tokens, context_mask, block_indices
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: for first sentence of the block
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.document_indices.dtype == np.int64
assert block_dataset.sequence_lengths.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.core.datasets import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.document_indices,
block_dataset.sequence_lengths,
title_dataset.sequence_lengths,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.tensor([1], dtype=torch.long, device='cuda')
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Dataloaders."""
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron.training import get_args
from megatron.core import mpu
def build_pretraining_data_loader(dataset, consumed_samples):
"""Build dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
# Megatron sampler
if args.dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
data_sharding=args.data_sharding)
elif args.dataloader_type == "external":
# External dataloaders are passed through. User is expected to provide a
# torch-compatible dataloader and define samplers, if needed.
return dataset
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=True if args.num_workers > 0 else False,
)
class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class RandomSeedDataset(Dataset):
def __init__(self, dataset):
args = get_args()
self.base_seed = args.seed
self.curr_seed = args.seed
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def set_epoch(self, epoch):
self.curr_seed = self.base_seed + epoch
def __getitem__(self, idx):
seed = idx + self.curr_seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
return self.dataset[idx]
class MegatronPretrainingRandomSampler:
def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, data_sharding):
# Keep a copy of input params for later use.
self.dataset = dataset
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.data_sharding = data_sharding
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
if isinstance(self.dataset, RandomSeedDataset):
self.dataset.set_epoch(self.epoch)
# data sharding and random sampling
if self.data_sharding:
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
else:
full_bucket_size = (self.total_samples // self.micro_batch_size) \
* self.micro_batch_size
full_bucket_offset = current_epoch_samples
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_active = idx_range_total[full_bucket_offset:]
idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment