Commit 2eea6216 authored by rprenger

Merging with main and fixing merge conflict

parents ed6806ac 5f694372
@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import beam_search_and_post_process

 GENERATE_NUM = 0
+BEAM_NUM = 1
 lock = threading.Lock()

 class MegatronGenerate(Resource):
@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
         choice = torch.cuda.LongTensor([GENERATE_NUM])
         torch.distributed.broadcast(choice, 0)

+    @staticmethod
+    def send_do_beam_search():
+        choice = torch.cuda.LongTensor([BEAM_NUM])
+        torch.distributed.broadcast(choice, 0)

     def put(self):
         args = get_args()
@@ -148,13 +155,55 @@ class MegatronGenerate(Resource):
         if not isinstance(no_log, bool):
             return "no_log must be a boolean value"

+        beam_width = None
+        if "beam_width" in request.get_json():
+            beam_width = request.get_json()["beam_width"]
+            if not isinstance(beam_width, int):
+                return "beam_width must be an integer"
+            if beam_width < 1:
+                return "beam_width must be an integer >= 1"
+            if len(prompts) > 1:
+                return "When doing beam_search, batch size must be 1"
+
+        stop_token = 50256
+        if "stop_token" in request.get_json():
+            stop_token = request.get_json()["stop_token"]
+            if not isinstance(stop_token, int):
+                return "stop_token must be an integer"
+
+        length_penalty = 1
+        if "length_penalty" in request.get_json():
+            length_penalty = request.get_json()["length_penalty"]
+            if not isinstance(length_penalty, float):
+                return "length_penalty must be a float"
         with lock:  # Need to get lock to keep multiple threads from hitting code

             if not no_log:
                 print("request IP: " + str(request.remote_addr))
                 print(json.dumps(request.get_json()), flush=True)
                 print("start time: ", datetime.datetime.now())

-            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
             try:
+                if beam_width is not None:
+                    MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
+                    response, response_seg, response_scores = \
+                        beam_search_and_post_process(
+                            self.model,
+                            prompts=prompts,
+                            tokens_to_generate=tokens_to_generate,
+                            beam_size=beam_width,
+                            add_BOS=add_BOS,
+                            stop_token=stop_token,
+                            num_return_gen=beam_width,  # Returning whole beam
+                            length_penalty=length_penalty)
+                    return jsonify({"text": response,
+                                    "segments": response_seg,
+                                    "scores": response_scores})
+                else:
+                    MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
                     response, response_seg, response_logprobs, _ = \
                         generate_and_post_process(
                             self.model,
@@ -171,14 +220,16 @@ class MegatronGenerate(Resource):
                             stop_on_double_eol=stop_on_double_eol,
                             stop_on_eol=stop_on_eol,
                             random_seed=random_seed)
-            except ValueError as ve:
-                return "Length of prompt + tokens_to_generate longer than allowed"
-
-            print("end time: ", datetime.datetime.now())
-            return jsonify({"text": response,
-                            "segments": response_seg,
-                            "logprobs": response_logprobs})
+                    return jsonify({"text": response,
+                                    "segments": response_seg,
+                                    "logprobs": response_logprobs})
+            except ValueError as ve:
+                return "Length of prompt + tokens_to_generate longer than allowed"
+
+            print("end time: ", datetime.datetime.now())
 class MegatronServer(object):
     def __init__(self, model):
         self.app = Flask(__name__, static_url_path='')
......
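The new `beam_width`, `stop_token`, and `length_penalty` fields ride along on the existing PUT payload. A minimal client sketch, assuming the server is reachable at `localhost:5000` and that the resource is mounted at `/api` (host, port, and path are not shown in this diff and are assumptions; the JSON fields match the validation code above):

```python
# Hypothetical client for the beam-search path added in this commit.
import requests

resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Deep learning is"],  # batch size must be 1 for beam search
        "tokens_to_generate": 32,
        "beam_width": 4,                  # int >= 1; omit to use sampling instead
        "stop_token": 50256,              # server default (GPT-2 end-of-text id)
        "length_penalty": 1.0,            # must be a float
    },
)
print(resp.json())  # {"text": [...], "segments": [...], "scores": [...]}
```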
@@ -42,6 +42,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
+from megatron.initialize import set_jit_fusion_options
 from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
@@ -99,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Initialize and get arguments, timers, and Tensorboard writer.
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
+    # Set pytorch JIT layer fusion options and warmup JIT functions.
+    set_jit_fusion_options()

     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
@@ -361,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
     args = get_args()

     model = get_model(model_provider_func, model_type)
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))

-    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-                                       scale_lr_cond, lr_mult)
+    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

     if args.load is not None:
@@ -409,78 +411,44 @@ def train_step(forward_step_func, data_iterator,
         partition.zero_grad_buffer()
     optimizer.zero_grad()

+    # Forward pass.
     forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)

-    # Empty unused memory
+    # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

-    # All-reduce if needed.
-    if args.DDP_impl == 'local':
-        timers('backward-params-all-reduce').start()
-        for model_module in model:
-            model_module.allreduce_gradients()
-        timers('backward-params-all-reduce').stop()
-
-    # All-reduce word_embeddings' grad across first and last stages to ensure
-    # that word_embeddings parameters stay in sync.
-    # This should only run for models that support pipelined model parallelism
-    # (BERT and GPT-2).
-    timers('backward-embedding-all-reduce').start()
-    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
-            mpu.get_pipeline_model_parallel_world_size() > 1:
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
-            unwrapped_model = model[0]
-        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-            unwrapped_model = model[-1]
-        else:  # We do not support the interleaved schedule for T5 yet.
-            unwrapped_model = model[0]
-        unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-        if unwrapped_model.share_word_embeddings:
-            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            if args.DDP_impl == 'local':
-                grad = word_embeddings_weight.main_grad
-            else:
-                grad = word_embeddings_weight.grad
-            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
-
-    # All-reduce position_embeddings grad across first (encoder) and split (decoder)
-    # stages to ensure that position embeddings parameters stay in sync.
-    # This should only run for T5 models with pipeline parallelism
-    if mpu.is_rank_in_position_embedding_group() and \
-            mpu.get_pipeline_model_parallel_world_size() > 1 and \
-            args.pipeline_model_parallel_split_rank is not None:
-        unwrapped_model = model[0]
-        unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-        assert args.DDP_impl == 'local', \
-            'T5 model is only supported with local DDP mode'
-        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
-    timers('backward-embedding-all-reduce').stop()
+    # Reduce gradients.
+    timers('backward-reduce-model-grads').start()
+    optimizer.reduce_model_grads(args, timers)
+    timers('backward-reduce-model-grads').stop()

+    # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

     # Update parameters.
     timers('optimizer').start()
-    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()

+    # Gather params.
+    if update_successful:
+        timers('backward-gather-model-params').start()
+        optimizer.gather_model_params(args, timers)
+        timers('backward-gather-model-params').stop()

+    # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.update_momentum(args.curr_iteration)

     # Update learning rate.
     if update_successful:
         increment = get_num_microbatches() * \
@@ -491,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    # Empty unused memory
+    # Empty unused memory.
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
@@ -558,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('backward-send-forward-recv')
     add_to_logging('backward-send-backward-recv')
     add_to_logging('backward-params-all-reduce')
+    add_to_logging('backward-layernorm-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
+    add_to_logging('backward-reduce-model-grads')
+    add_to_logging('backward-gather-model-params')
     add_to_logging('optimizer-copy-to-main-grad')
     add_to_logging('optimizer-unscale-and-check-inf')
     add_to_logging('optimizer-clip-main-grad')
+    add_to_logging('optimizer-count-zeros')
+    add_to_logging('optimizer-inner-step')
     add_to_logging('optimizer-copy-main-to-model-params')
     add_to_logging('optimizer')
     add_to_logging('batch-generator')
......
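The net effect of the train_step hunk above: the inline data-parallel all-reduce and the embedding-grad synchronization across pipeline stages are folded into the optimizer behind `reduce_model_grads`, with a matching `gather_model_params` after a successful step. A minimal sketch of the resulting control flow, assuming only the calls shown in the diff (the helper `run_forward_backward` is a stand-in for the forward_backward_func machinery, not Megatron's API):

```python
# Sketch of the new train_step skeleton after this commit.
def train_step_sketch(optimizer, timers, args, run_forward_backward):
    losses_reduced = run_forward_backward()       # forward + backward passes

    # Grad reduction now lives inside the optimizer (replaces the removed
    # allreduce_gradients / embedding all-reduce block).
    timers('backward-reduce-model-grads').start()
    optimizer.reduce_model_grads(args, timers)
    timers('backward-reduce-model-grads').stop()

    # step() now takes (args, timers) so it can time its internal phases.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()

    # New: re-gather updated params (only if the step wasn't skipped).
    if update_successful:
        timers('backward-gather-model-params').start()
        optimizer.gather_model_params(args, timers)
        timers('backward-gather-model-params').stop()

    return losses_reduced, update_successful
```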
@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,

     return attention_mask, loss_mask, position_ids

+
+def print_rank_0(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
+def is_last_rank():
+    return torch.distributed.get_rank() == (
+        torch.distributed.get_world_size() - 1)
+
+
+def print_rank_last(message):
+    """If distributed is initialized, print only on last rank."""
+    if torch.distributed.is_initialized():
+        if is_last_rank():
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
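These helpers, moved here so that `megatron.utils` no longer imports `print_rank_0` from the package root, act as one-line rank guards. A small usage sketch, assuming a torch.distributed process group is already initialized:

```python
# Sketch: rank-guarded logging with the helpers defined above.
from megatron.utils import print_rank_0, print_rank_last

print_rank_0("building dataloaders ...")   # printed once, by global rank 0
print_rank_last("validation loss: 2.31")   # printed once, by the last rank
```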
@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step,
                                        valid_dataloader, model,
-                                       iteration, False)
+                                       iteration, None, False)

         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
......
@@ -15,12 +15,15 @@
 """Vision-classification finetuning/evaluation."""

-from megatron import get_args
+import torch.nn.functional as F
+from functools import partial
+from megatron import get_args, get_timers
 from megatron import print_rank_0
-from megatron.model.vit_model import VitModel
+from megatron.model.vision.classification import VitClassificationModel
 from megatron.data.vit_dataset import build_train_valid_datasets
-from tasks.vision.eval_utils import accuracy_func_provider
+from tasks.vision.classification.eval_utils import accuracy_func_provider
 from tasks.vision.finetune_utils import finetune
+from megatron.utils import average_losses_across_data_parallel_group


 def classification():
@@ -30,7 +33,7 @@ def classification():
         train_ds, valid_ds = build_train_valid_datasets(
             data_path=args.data_path,
-            crop_size=args.img_dim,
+            image_size=(args.img_h, args.img_w),
         )
         return train_ds, valid_ds
@@ -40,16 +43,52 @@ def classification():
         print_rank_0("building classification model for ImageNet ...")

-        return VitModel(num_classes=args.num_classes, finetune=True,
-                        pre_process=pre_process, post_process=post_process)
+        return VitClassificationModel(num_classes=args.num_classes, finetune=True,
+                                      pre_process=pre_process, post_process=post_process)
+
+    def process_batch(batch):
+        """Process batch and produce inputs for the model."""
+        images = batch[0].cuda().contiguous()
+        labels = batch[1].cuda().contiguous()
+        return images, labels
+
+    def cross_entropy_loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        # Cross-entropy loss.
+        loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+        # Reduce loss for logging.
+        averaged_loss = average_losses_across_data_parallel_group([loss])
+
+        return loss, {'lm loss': averaged_loss[0]}
+
+    def _cross_entropy_forward_step(batch, model):
+        """Simple forward step with cross-entropy loss."""
+        timers = get_timers()
+
+        # Get the batch.
+        timers("batch generator").start()
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+        timers("batch generator").stop()

+        # Forward model.
+        output_tensor = model(images)
+
+        return output_tensor, partial(cross_entropy_loss_func, labels)

     """Finetune/evaluate."""
     finetune(
         train_valid_datasets_provider,
         model_provider,
+        forward_step=_cross_entropy_forward_step,
         end_of_epoch_callback_provider=accuracy_func_provider,
     )


 def main():
     classification()
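The forward_step contract here is the usual Megatron one: return the raw output tensor plus a callable that turns it into a (loss, logging dict) pair. A hedged sketch of how a caller consumes the pair; `batch_iter` and `model` are assumed to be in scope:

```python
# Illustrative consumer of the (output_tensor, loss_func) pair returned by
# _cross_entropy_forward_step above; not Megatron's actual trainer loop.
output_tensor, loss_func = _cross_entropy_forward_step(batch_iter, model)
loss, logging_dict = loss_func(output_tensor)  # partial already bound labels
loss.backward()
print(logging_dict['lm loss'])
```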
@@ -33,11 +33,10 @@ def accuracy_func_provider():
     """Provide function that calculates accuracies."""
     args = get_args()
     data_path = args.data_path
-    crop_size = args.img_dim
+    crop_size = (args.img_h, args.img_w)

-    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
     # Build dataloaders.
-    val_data_path = os.path.join(data_path[0], "val")
+    val_data_path = data_path[1]
     normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
     transform_val = transforms.Compose(
         [
@@ -54,6 +53,7 @@ def accuracy_func_provider():
         args.micro_batch_size,
         num_workers=args.num_workers,
         drop_last=(mpu.get_data_parallel_world_size() > 1),
+        shuffle=False
     )

     def metrics_func(model, epoch):
@@ -71,7 +71,6 @@ def accuracy_func_provider():
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""
-    args = get_args()
     forward_backward_func = get_forward_backward_func()
     for m in model:
         m.eval()
@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
         images, labels = process_batch(batch_)

         # Forward model.
-        args = get_args()
         output_tensor = model(images)

         return output_tensor, partial(loss_func, labels)
......
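Note the quiet convention change in the eval_utils hunk: validation data now comes from the second `--data-path` entry rather than a `val/` subdirectory of the first. A hypothetical invocation matching the new code (paths are made up for illustration):

```python
# Hypothetical: after this change, args.data_path is expected to carry the
# train and val directories explicitly, e.g. --data-path <train_dir> <val_dir>.
data_path = ["/data/imagenet/train", "/data/imagenet/val"]  # assumed layout
val_data_path = data_path[1]  # was: os.path.join(data_path[0], "val")
```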
This diff is collapsed.