Commit 80389ef6 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into checkpoint_util

parents 1b2db724 d07d29df
......@@ -26,6 +26,7 @@ from .communication import (
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
from .beam_utils import BeamHypotheses
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
......@@ -200,6 +201,7 @@ def generate_tokens_probs_and_return_on_first_stage(
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# If a prompt length is smaller than or equal to the current context
# length, it means we have started generating tokens
started = lengths <= context_length
......@@ -257,7 +259,7 @@ def generate_tokens_probs_and_return_on_first_stage(
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the length based on the max generated length.
# ===================================================
......@@ -280,6 +282,118 @@ def generate_tokens_probs_and_return_on_first_stage(
return tokens, generated_sequence_lengths, output_log_probs
def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
    """Run beam search for a single prompt and return results on the first pipeline stage.

    Arguments:
        model: the (possibly pipeline-parallel) model used by ForwardStep.
        tokens: prompt token tensor; batch dimension must be 1.
        lengths: tensor holding the prompt length (read via .item()).
        beam_size: number of beams kept alive at each step.
        stop_token: token id that terminates a beam (treated as EOS).
        num_return_gen: number of ranked hypotheses to return (capped by
            the number of finished hypotheses).
        length_penalty: length penalty forwarded to BeamHypotheses scoring.

    Returns:
        (tokens, scores): stacked token sequences and their scores for the
        top `num_return_gen` hypotheses, ranked best-first.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Beam search here only supports a single prompt at a time.
    batch_size = tokens.size(0)
    assert(batch_size == 1)
    prompt_length = lengths.item()
    final_sequence_length = tokens.size(1)
    # Never generate past the model's position-embedding capacity.
    final_sequence_length = min(final_sequence_length, args.max_position_embeddings)

    # If the context is too big, this happens
    if prompt_length >= final_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = ForwardStep(model, beam_size, final_sequence_length)

    beam_hyp = BeamHypotheses(beam_size, length_penalty)
    done = False
    # Per-beam cumulative log-prob scores, shape [beam_size, 1] so they
    # broadcast against the [beam_size, vocab] log-prob rows below.
    scores = torch.zeros(beam_size,
                         dtype=torch.float32,
                         device=torch.cuda.current_device()).unsqueeze(1)

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        # Replicate the single prompt across all beams.
        tokens = tokens.repeat(beam_size, 1)
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(prompt_length, final_sequence_length):

            # Pick the slice that we need to pass through the network.
            # Only the not-yet-processed positions are fed; earlier positions
            # are cached in forward_step's inference key/value state.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                vocab_size = logits.size(2)
                log_probs = F.log_softmax(logits, dim=2)
                # Candidate score = running beam score + next-token log-prob.
                new_scores = log_probs[:, -1, :] + scores

                if context_length == prompt_length:  # if this is the first one
                    # All beams are identical copies of the prompt, so only
                    # consider row 0 to avoid duplicated candidates.
                    sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
                else:
                    # Rank all beam*vocab candidates jointly.
                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

                # Decode flat candidate index into (beam id, token id).
                # 2 * beam_size candidates are kept so that up to beam_size
                # EOS hits can be absorbed while still refilling every beam.
                best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
                best_words = indices[:2 * beam_size] % vocab_size
                best_scores = sorted_scores[: 2 * beam_size]

                next_beams = []
                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
                    zip(best_words, best_scores, best_beam_ids)
                ):
                    if token_id.item() == stop_token:
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        # Finished hypothesis: store a snapshot of the beam's
                        # tokens together with its score and generated length.
                        beam_hyp.add(
                            tokens[beam_id].clone(),
                            beam_score,
                            context_length + 1 - prompt_length
                        )
                    else:
                        # add next predicted token since it is not eos_token
                        next_beams.append((token_id, beam_score, beam_id))

                    if len(next_beams) == beam_size:
                        break

                if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                    done = True
                    break

                # Reorder the live beams to match the surviving candidates,
                # then append each candidate's new token and carry its score.
                best_batches = tokens.new([item[2] for item in next_beams])
                tokens = tokens[best_batches, :]
                tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
                scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)

                # set inference key values to make it consistent with best beam index
                forward_step.inference_params.swap_key_value_dict(best_batches)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

        # Ship the final beam scores to the first stage as well.
        copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
                                               scores[:, 0])

        # if cannot find stop token, add open beams to hyps
        if not done:
            for beam_id in range(beam_size):
                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)

        # rank based on scores
        # NOTE(review): beam_hyp.beams entries look like (score, tokens)
        # tuples given the indexing below — confirm against BeamHypotheses.
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
        num_return_gen = min(num_return_gen, len(sorted_hyps))
        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
        scores = torch.stack(scores, dim=0)
        tokens = torch.stack(tokens, dim=0)

    return tokens, scores
def _build_attention_mask_and_position_ids(tokens):
......
......@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
......@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
    @staticmethod
    def send_do_beam_search():
        """Broadcast BEAM_NUM from rank 0 so the other ranks branch into
        the beam-search code path (mirrors send_do_generate/GENERATE_NUM)."""
        choice = torch.cuda.LongTensor([BEAM_NUM])
        torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
......@@ -128,15 +135,57 @@ class MegatronGenerate(Resource):
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token=50256
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
try:
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
if beam_width is not None:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
......@@ -149,13 +198,15 @@ class MegatronGenerate(Resource):
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object):
def __init__(self, model):
......
......@@ -42,6 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
......@@ -99,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
......@@ -361,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
......@@ -409,78 +411,44 @@ def train_step(forward_step_func, data_iterator,
partition.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
......@@ -491,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
# Empty unused memory
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
......@@ -558,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-layernorm-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-reduce-model-grads')
add_to_logging('backward-gather-model-params')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-count-zeros')
add_to_logging('optimizer-inner-step')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
......
......@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model,
iteration, False)
iteration, None, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
......
......@@ -15,12 +15,15 @@
"""Vision-classification finetuning/evaluation."""
from megatron import get_args
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vit_model import VitModel
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group
def classification():
......@@ -30,7 +33,7 @@ def classification():
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
crop_size=args.img_dim,
image_size=(args.img_h, args.img_w),
)
return train_ds, valid_ds
......@@ -40,16 +43,52 @@ def classification():
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
return VitClassificationModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
def process_batch(batch):
    """Process batch and produce inputs for the model.

    Expects ``batch`` to be an (images, labels) pair; both tensors are
    moved to the current GPU and made contiguous.
    """
    images, labels = batch[0], batch[1]
    return images.cuda().contiguous(), labels.cuda().contiguous()
def cross_entropy_loss_func(labels, output_tensor):
    """Cross-entropy loss plus its data-parallel average for logging.

    Arguments:
        labels: integer class labels.
        output_tensor: raw logits from the model.

    Returns:
        (loss, {'lm loss': averaged_loss}) — the raw loss for the backward
        pass and the data-parallel-averaged value for logging.
    """
    # Cross-entropy on float32 logits against the integer labels.
    loss = F.cross_entropy(output_tensor.contiguous().float(), labels)

    # Average across the data-parallel group so every rank logs one value.
    reduced = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': reduced[0]}
def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss.

    Returns the model output tensor and a loss closure bound to the
    batch's labels, as expected by the Megatron forward/backward runner.
    """
    timers = get_timers()

    # Get the batch. `batch` may be a data iterator (the usual case) or an
    # already-materialized batch; next() on a non-iterator raises, in which
    # case the object is used directly.
    # NOTE(review): BaseException is deliberately broad here (Megatron-wide
    # idiom) — it also swallows StopIteration on an exhausted iterator.
    timers("batch generator").start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    images, labels = process_batch(batch_)
    timers("batch generator").stop()

    # Forward model.
    output_tensor = model(images)

    # Loss is deferred: the runner calls the partial on the output tensor.
    return output_tensor, partial(cross_entropy_loss_func, labels)
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
    """Entry point: run ViT classification finetuning/evaluation."""
    classification()
......@@ -33,11 +33,10 @@ def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = args.img_dim
crop_size = (args.img_h, args.img_w)
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = os.path.join(data_path[0], "val")
val_data_path = data_path[1]
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
......@@ -54,6 +53,7 @@ def accuracy_func_provider():
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
shuffle=False
)
def metrics_func(model, epoch):
......@@ -71,7 +71,6 @@ def accuracy_func_provider():
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
args = get_args()
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
......@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
images, labels = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
......
This diff is collapsed.
......@@ -28,32 +28,24 @@ sys.path.append(
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
    """Provide extra arguments required for vision tasks.

    Registers the task-selection and finetuning options on `parser` and
    returns it. The merge residue in this hunk registered --epochs and
    --pretrained-checkpoint twice, which makes argparse raise
    ArgumentError at startup; each option is now registered exactly once.
    """
    group = parser.add_argument_group(title="tasks")
    group.add_argument('--task', type=str, default='segment',
                       choices=['classify', 'segment_setr', 'segment_segformer'],
                       help='task name.')
    group.add_argument("--epochs", type=int, default=None,
                       help="Number of finetunning epochs. Zero results in "
                       "evaluation only.")
    # 'constrastive' spelling kept as-is: callers already pass this choice.
    group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
                       choices=['default', 'external', 'constrastive'],
                       help='Type of pretrained checkpoint')
    group.add_argument("--pretrained-checkpoint", type=str, default=None,
                       help="Pretrained checkpoint used for finetunning.")
    group.add_argument('--seg-stride', type=int, default=None,
                       help='sliding window stride during evaluation')
    return parser
......@@ -61,4 +53,14 @@ if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
main()
if args.task == 'classify':
from tasks.vision.classification.classification import main
main()
elif args.task == 'segment_setr':
from tasks.vision.segmentation.finetune_setr import main
main()
elif args.task == 'segment_segformer':
from tasks.vision.segmentation.finetune_segformer import main
main()
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -122,8 +122,10 @@ def get_args():
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch')
group.add_argument('--chunk-size', type=int, required=True,
help='Chunk size assigned to each worker process')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
......@@ -154,7 +156,7 @@ def main():
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25)
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
#encoded_docs = map(encoder.encode, fin)
level = "document"
......
......@@ -28,6 +28,7 @@ from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
import torch
def model_provider(pre_process=True, post_process=True):
......@@ -82,3 +83,8 @@ if __name__ == "__main__":
generate_and_post_process(model)
except ValueError as ve:
pass
elif choice[0].item() == 1:
try:
beam_search_and_post_process(model)
except ValueError as ve:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment