Commit 80389ef6 authored by Jared Casper

Merge branch 'main' into checkpoint_util

parents 1b2db724 d07d29df
...@@ -26,6 +26,7 @@ from .communication import (
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
from .beam_utils import BeamHypotheses
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
...@@ -200,6 +201,7 @@ def generate_tokens_probs_and_return_on_first_stage(
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# If a prompt length is smaller than or equal to the current context
# length, it means we have started generating tokens
started = lengths <= context_length
...@@ -280,6 +282,118 @@ def generate_tokens_probs_and_return_on_first_stage(
return tokens, generated_sequence_lengths, output_log_probs
def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
assert(batch_size == 1)
prompt_length = lengths.item()
final_sequence_length = tokens.size(1)
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if prompt_length >= final_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = ForwardStep(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
done = False
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
# =============
# Run inference
# =============
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
prev_context_length = 0
for context_length in range(prompt_length, final_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
# logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
vocab_size = logits.size(2)
log_probs = F.log_softmax(logits, dim=2)
new_scores = log_probs[:, -1, :] + scores
if context_length == prompt_length: # if this is the first one
sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
else:
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
zip(best_words, best_scores, best_beam_ids)
):
if token_id.item() == stop_token:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
tokens[beam_id].clone(),
beam_score,
context_length + 1 - prompt_length
)
else:
# add next predicted token since it is not eos_token
next_beams.append((token_id, beam_score, beam_id))
if len(next_beams) == beam_size:
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = True
break
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# set inference key values to make it consistent with best beam index
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
scores[:,0])
# if no stop token was found, add the open beams to the hypotheses
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
return tokens, scores
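The index arithmetic in the loop above (flatten the per-beam scores, sort, then recover beam and token ids from the flat indices) is easy to misread. The sketch below is a self-contained toy with small shapes, plain PyTorch only; the names are illustrative and not part of the commit.

```python
# Toy reproduction of the (beam, token) bookkeeping used in the beam search above.
import torch

beam_size, vocab_size = 3, 10
scores = torch.zeros(beam_size, 1)                      # running log-prob per beam
log_probs = torch.randn(beam_size, vocab_size).log_softmax(dim=-1)

new_scores = log_probs + scores                         # [beam_size, vocab_size]
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

# A flat index i corresponds to beam i // vocab_size and token i % vocab_size.
best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[:2 * beam_size]

for rank, (tok, score, beam) in enumerate(zip(best_words, best_scores, best_beam_ids)):
    print(rank, int(beam), int(tok), round(float(score), 3))
```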
def _build_attention_mask_and_position_ids(tokens):
...
...@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
...@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
choice = torch.cuda.LongTensor([BEAM_NUM])
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
...@@ -128,13 +135,55 @@ class MegatronGenerate(Resource):
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token=50256
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
try:
if beam_width is not None:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
...@@ -149,14 +198,16 @@ class MegatronGenerate(Resource):
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response, return jsonify({"text": response,
"segments": response_seg, "segments": response_seg,
"logprobs": response_logprobs}) "logprobs": response_logprobs})
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
...
...@@ -42,6 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
...@@ -99,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
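The body of `set_jit_fusion_options` is not shown in this excerpt. The sketch below only illustrates the kind of TorchScript fusion flags such a helper typically flips; these are private `torch._C` APIs that vary across PyTorch releases, so treat every call here as an assumption rather than the Megatron implementation.

```python
import torch

def set_jit_fusion_options_sketch():
    # Illustrative only; check your PyTorch version before relying on these.
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    if torch.cuda.is_available():
        torch._C._jit_set_nvfuser_enabled(True)
```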
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
...@@ -361,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
...@@ -409,78 +411,44 @@ def train_step(forward_step_func, data_iterator,
partition.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
...@@ -491,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
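The diff above moves the gradient all-reduce and the post-step parameter gather out of `train_step` and behind `optimizer.reduce_model_grads` / `optimizer.gather_model_params`. The toy wrapper below only sketches that pattern with plain `torch.distributed`; the method bodies are assumptions, not Megatron-LM's implementation.

```python
import torch
import torch.distributed as dist

class ToyDistributedOptimizer:
    """Sketch of an optimizer that owns gradient communication."""

    def __init__(self, inner, params):
        self.inner = inner            # any torch.optim optimizer
        self.params = list(params)

    def reduce_model_grads(self, args=None, timers=None):
        # Average gradients across data-parallel ranks.
        if dist.is_available() and dist.is_initialized():
            world = dist.get_world_size()
            for p in self.params:
                if p.grad is not None:
                    dist.all_reduce(p.grad)
                    p.grad /= world

    def step(self, args=None, timers=None):
        self.inner.step()
        # Mirror the (update_successful, grad_norm, num_zeros_in_grad) contract above.
        return True, None, None

    def gather_model_params(self, args=None, timers=None):
        # A sharded (ZeRO-style) optimizer would all-gather updated parameters here.
        pass
```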
...@@ -558,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-layernorm-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-reduce-model-grads')
add_to_logging('backward-gather-model-params')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-count-zeros')
add_to_logging('optimizer-inner-step')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
...
...@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model,
iteration, False)
iteration, None, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
...
...@@ -15,12 +15,15 @@
"""Vision-classification finetuning/evaluation."""
from megatron import get_args
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vit_model import VitModel
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group
def classification():
...@@ -30,7 +33,7 @@ def classification():
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
crop_size=args.img_dim,
image_size=(args.img_h, args.img_w),
)
return train_ds, valid_ds
...@@ -40,16 +43,52 @@ def classification():
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True,
return VitClassificationModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
return images, labels
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss = F.cross_entropy(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, labels)
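The forward step above returns the model output together with a loss closure built via `functools.partial`, and the trainer applies that closure to the output later. A stripped-down, stand-alone illustration of that contract follows (plain PyTorch, toy model, no Megatron imports; names are illustrative).

```python
from functools import partial

import torch
import torch.nn.functional as F

def loss_func(labels, output_tensor):
    loss = F.cross_entropy(output_tensor.float(), labels)
    return loss, {"lm loss": loss.detach()}

model = torch.nn.Linear(8, 5)
images = torch.randn(4, 8)
labels = torch.randint(0, 5, (4,))

# What a forward step returns: the output and a closure over the labels.
output_tensor, closure = model(images), partial(loss_func, labels)
loss, stats = closure(output_tensor)
print(loss.item(), stats)
```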
"""Finetune/evaluate.""" """Finetune/evaluate."""
finetune( finetune(
train_valid_datasets_provider, train_valid_datasets_provider,
model_provider, model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
classification()
...@@ -33,11 +33,10 @@ def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = args.img_dim
crop_size = (args.img_h, args.img_w)
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = os.path.join(data_path[0], "val")
val_data_path = data_path[1]
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
...@@ -54,6 +53,7 @@ def accuracy_func_provider():
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
shuffle=False
)
def metrics_func(model, epoch):
...@@ -71,7 +71,6 @@ def accuracy_func_provider():
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
args = get_args()
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
...@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
images, labels = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
...
...@@ -28,32 +28,24 @@ sys.path.append(
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="tasks")
group.add_argument('--task', type=str, default='segment',
choices=['classify', 'segment_setr', 'segment_segformer'],
help='task name.')
group.add_argument("--epochs", type=int, default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.")
group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
choices=['default', 'external', 'constrastive'],
help='Type of pretrained checkpoint')
group.add_argument("--pretrained-checkpoint", type=str, default=None,
help="Pretrained checkpoint used for finetunning.")
group.add_argument('--seg-stride', type=int, default=None,
help='sliding window stride during evaluation')
group.add_argument(
"--epochs",
type=int,
default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.",
)
group.add_argument(
"--pretrained-checkpoint",
type=str,
default=None,
help="Pretrained checkpoint used for finetunning.",
)
group.add_argument(
"--keep-last",
action="store_true",
help="Keep the last batch (maybe incomplete) in" "the data loader",
)
return parser
...@@ -61,4 +53,14 @@ if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.task == 'classify':
from tasks.vision.classification.classification import main
main()
elif args.task == 'segment_setr':
from tasks.vision.segmentation.finetune_setr import main
main()
elif args.task == 'segment_segformer':
from tasks.vision.segmentation.finetune_segformer import main
main()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead
class SetrSegmentationModel(MegatronModule):
def __init__(self,
num_classes,
pre_process=True,
post_process=True):
super(SetrSegmentationModel, self).__init__()
args = get_args()
assert post_process & pre_process
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = VitBackbone(
pre_process=pre_process,
post_process=post_process,
class_token=False,
post_layer_norm=False,
drop_path_rate=0.1
)
self.head = SetrSegmentationHead(
self.hidden_size,
self.num_classes
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
# [b hw c]
hidden_states = self.backbone(input)
result_final = self.head(hidden_states)
return result_final
class SegformerSegmentationModel(MegatronModule):
def __init__(self,
num_classes,
pre_process=True,
post_process=True):
super(SegformerSegmentationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
self.backbone = mit_b5()
self.head = SegformerSegmentationHead(
feature_strides=[4, 8, 16, 32],
in_channels=[64, 128, 320, 512],
embedding_dim=768,
dropout_ratio=0.1
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
# [b hw c]
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
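Both classes follow the same backbone-then-head wiring. The stand-alone toy below shows that structure with plain `torch.nn` modules; VitBackbone, mit_b5, and the Megatron segmentation heads are not used here, and the shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySegmentationModel(nn.Module):
    """Minimal backbone -> segmentation-head wiring, mirroring the classes above."""

    def __init__(self, num_classes, hidden_size=32):
        super().__init__()
        self.backbone = nn.Conv2d(3, hidden_size, kernel_size=3, padding=1)
        self.head = nn.Conv2d(hidden_size, num_classes, kernel_size=1)

    def forward(self, images):
        features = self.backbone(images)              # [b, hidden, h, w]
        logits = self.head(features)                  # [b, num_classes, h, w]
        return F.interpolate(logits, size=images.shape[-2:],
                             mode="bilinear", align_corners=False)

model = ToySegmentationModel(num_classes=19)
print(model(torch.randn(2, 3, 64, 64)).shape)         # torch.Size([2, 19, 64, 64])
```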
...@@ -122,8 +122,10 @@ def get_args():
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch')
group.add_argument('--chunk-size', type=int, required=True,
help='Chunk size assigned to each worker process')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
...@@ -154,7 +156,7 @@ def main():
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25)
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
#encoded_docs = map(encoder.encode, fin)
level = "document"
...
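The new `--chunk-size` flag is handed straight to `Pool.imap` as its `chunksize` argument: each worker receives that many input lines per dispatch, which cuts inter-process overhead on large corpora. A self-contained toy showing the same call shape (the encoder here is a stand-in, not the script's `Encoder`):

```python
import multiprocessing

def encode(line):
    # Stand-in for Encoder.encode: count whitespace-separated tokens.
    return len(line.split())

if __name__ == "__main__":
    lines = ["some document text"] * 10_000
    with multiprocessing.Pool(processes=4) as pool:
        # 128 plays the role of args.chunk_size in the diff above.
        total = sum(pool.imap(encode, lines, 128))
    print(total)
```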