Commit 80389ef6 authored by Jared Casper

Merge branch 'main' into checkpoint_util

parents 1b2db724 d07d29df
@@ -26,6 +26,7 @@ from .communication import (
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
from .beam_utils import BeamHypotheses
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
@@ -200,6 +201,7 @@ def generate_tokens_probs_and_return_on_first_stage(
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# If a prompt length is smaller than or equal to the current context
# length, it means we have started generating tokens
started = lengths <= context_length
@@ -257,7 +259,7 @@ def generate_tokens_probs_and_return_on_first_stage(
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
# Update the length based on the max generated length.
# ===================================================
@@ -280,6 +282,118 @@ def generate_tokens_probs_and_return_on_first_stage(
return tokens, generated_sequence_lengths, output_log_probs
def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
assert(batch_size == 1)
prompt_length = lengths.item()
final_sequence_length = tokens.size(1)
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
# If the prompt already fills the sequence, there is no room left to generate
if prompt_length >= final_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = ForwardStep(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
done = False
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
# =============
# Run inference
# =============
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
prev_context_length = 0
for context_length in range(prompt_length, final_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# Logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
vocab_size = logits.size(2)
log_probs = F.log_softmax(logits, dim=2)
new_scores = log_probs[:, -1, :] + scores
if context_length == prompt_length: # if this is the first one
sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
else:
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
zip(best_words, best_scores, best_beam_ids)
):
if token_id.item() == stop_token:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
tokens[beam_id].clone(),
beam_score,
context_length + 1 - prompt_length
)
else:
# add next predicted token since it is not eos_token
next_beams.append((token_id, beam_score, beam_id))
if len(next_beams) == beam_size:
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = True
break
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# set inference key values to make it consistent with best beam index
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
scores[:,0])
# if cannot find stop token, add open beams to hyps
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
return tokens, scores
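Note: beam_utils.py itself is not part of this hunk. The sketch below only illustrates the interface that beam_search_and_return_on_first_stage relies on above (a beams list plus add() and is_done() with a length penalty), modeled on the common HuggingFace-style container; the constructor signature and internal bookkeeping are assumptions, not the file's actual contents.

class BeamHypotheses:
    """Keeps the best finished hypotheses, ranked by length-penalized score."""

    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.beams = []        # list of (score, token_tensor) pairs
        self.worst_score = 1e9

    def __len__(self):
        return len(self.beams)

    def add(self, hyp, sum_logprobs, length):
        # Normalize the summed log-probabilities by length ** length_penalty.
        score = sum_logprobs / (length ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                # Evict the current worst hypothesis.
                sorted_scores = sorted((s, idx) for idx, (s, _) in enumerate(self.beams))
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        # The search can stop once no open beam can still beat the worst kept hypothesis.
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        return self.worst_score >= best_sum_logprobs / (cur_len ** self.length_penalty)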
def _build_attention_mask_and_position_ids(tokens):
...
@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
choice = torch.cuda.LongTensor([BEAM_NUM])
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
@@ -128,15 +135,57 @@ class MegatronGenerate(Resource):
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token=50256
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
try:
if beam_width is not None:
MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
@@ -149,13 +198,15 @@ class MegatronGenerate(Resource):
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object):
def __init__(self, model):
...
@@ -42,6 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
@@ -99,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
@@ -361,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
@@ -409,78 +411,44 @@ def train_step(forward_step_func, data_iterator,
partition.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
@@ -491,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
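The train_step changes above replace the hand-written all-reduces with three optimizer hooks: reduce_model_grads() before the step, step(args, timers) returning (update_successful, grad_norm, num_zeros_in_grad), and gather_model_params() after a successful step. A minimal single-process sketch of that call sequence follows; the ToyDistributedOptimizer class and its no-op bodies are illustrative assumptions, not Megatron's optimizer implementation.

import torch

class ToyDistributedOptimizer:
    """Toy stand-in that mimics the call sequence used by train_step above."""

    def __init__(self, params, lr=0.1):
        self.inner = torch.optim.SGD(params, lr=lr)

    def reduce_model_grads(self, args=None, timers=None):
        # Megatron reduces gradients across data-parallel ranks here;
        # with a single process there is nothing to reduce.
        pass

    def step(self, args=None, timers=None):
        self.inner.step()
        # Mirrors the (update_successful, grad_norm, num_zeros_in_grad) tuple above.
        return True, None, None

    def gather_model_params(self, args=None, timers=None):
        # Megatron all-gathers updated parameter shards here; no-op in this toy.
        pass

# Usage mirroring the new train_step flow.
model = torch.nn.Linear(4, 2)
optimizer = ToyDistributedOptimizer(model.parameters())
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.reduce_model_grads()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
if update_successful:
    optimizer.gather_model_params()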
@@ -558,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-layernorm-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-reduce-model-grads')
add_to_logging('backward-gather-model-params')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-count-zeros')
add_to_logging('optimizer-inner-step')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
...
@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model,
iteration, None, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
...
@@ -15,12 +15,15 @@
"""Vision-classification finetuning/evaluation."""
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group
def classification():
@@ -30,7 +33,7 @@ def classification():
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w),
)
return train_ds, valid_ds
@@ -40,16 +43,52 @@ def classification():
print_rank_0("building classification model for ImageNet ...")
return VitClassificationModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
return images, labels
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss = F.cross_entropy(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, labels)
"""Finetune/evaluate.""" """Finetune/evaluate."""
finetune( finetune(
train_valid_datasets_provider, train_valid_datasets_provider,
model_provider, model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
classification()
@@ -33,11 +33,10 @@ def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = (args.img_h, args.img_w)
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = data_path[1]
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
@@ -54,6 +53,7 @@ def accuracy_func_provider():
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
shuffle=False
)
def metrics_func(model, epoch):
@@ -71,7 +71,6 @@ def accuracy_func_provider():
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
args = get_args()
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
images, labels = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
...
@@ -28,32 +28,24 @@ sys.path.append(
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="tasks")
group.add_argument('--task', type=str, default='segment',
choices=['classify', 'segment_setr', 'segment_segformer'],
help='task name.')
group.add_argument("--epochs", type=int, default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.")
group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
choices=['default', 'external', 'constrastive'],
help='Type of pretrained checkpoint')
group.add_argument("--pretrained-checkpoint", type=str, default=None,
help="Pretrained checkpoint used for finetunning.")
group.add_argument('--seg-stride', type=int, default=None,
help='sliding window stride during evaluation')
group.add_argument(
"--keep-last",
action="store_true",
help="Keep the last batch (maybe incomplete) in" "the data loader",
)
return parser
@@ -61,4 +53,14 @@ if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
main()
if args.task == 'classify':
from tasks.vision.classification.classification import main
main()
elif args.task == 'segment_setr':
from tasks.vision.segmentation.finetune_setr import main
main()
elif args.task == 'segment_segformer':
from tasks.vision.segmentation.finetune_segformer import main
main()
@@ -122,8 +122,10 @@ def get_args():
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch')
group.add_argument('--chunk-size', type=int, required=True,
help='Chunk size assigned to each worker process')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
@@ -154,7 +156,7 @@ def main():
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
#encoded_docs = map(encoder.encode, fin)
level = "document"
...
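The new --chunk-size flag feeds straight into multiprocessing.Pool.imap as its chunksize argument (previously hard-coded to 25): larger chunks mean fewer, bigger work batches per worker, which cuts inter-process dispatch overhead at the cost of coarser load balancing. A self-contained toy illustration of the same pattern (the encode function here is a placeholder, not the script's Encoder):

import multiprocessing

def encode(line):
    # Placeholder for Encoder.encode: pretend tokenization is whitespace splitting.
    return line.split()

if __name__ == "__main__":
    lines = ("document %d with some text" % i for i in range(10000))
    with multiprocessing.Pool(processes=4) as pool:
        # chunksize plays the role of --chunk-size: each worker receives
        # this many lines per dispatched task.
        n_tokens = sum(len(doc) for doc in pool.imap(encode, lines, chunksize=128))
    print(n_tokens)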
@@ -28,6 +28,7 @@ from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
import torch
def model_provider(pre_process=True, post_process=True):
@@ -82,3 +83,8 @@ if __name__ == "__main__":
generate_and_post_process(model)
except ValueError as ve:
pass
elif choice[0].item() == 1:
try:
beam_search_and_post_process(model)
except ValueError as ve:
pass