Commit ccae9dbd authored by Vijay Korthikanti

vit pipeline fixes

parent 42c1cf42
@@ -60,8 +60,8 @@ def check_checkpoint_args(checkpoint_args):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
-    _compare('max_position_embeddings')
     if args.vocab_file:
+        _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')
...
@@ -50,11 +50,11 @@ class VitMlpHead(MegatronModule):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        x = hidden_states[:, sequence_index, :]
-        x = self.dense_in(x)
-        x = torch.tanh(x)
-        x = self.dense_out(x)
-        return x
+        hidden_state = hidden_states[:, sequence_index, :]
+        dense_in_result = self.dense_in(hidden_state)
+        tanh_result = torch.tanh(dense_in_result)
+        dense_out_result = self.dense_out(tanh_result)
+        return dense_out_result
 
 
 def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""
 
-    def __init__(self, num_classes, finetune=False):
-        super(VitModel, self).__init__()
+    def __init__(self,
+                 num_classes,
+                 finetune=False,
+                 pre_process=True,
+                 post_process=True):
+        super(VitModel, self).__init__(share_word_embeddings=False)
         args = get_args()
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ class VitModel(MegatronModule):
             args.init_method_std, args.num_layers
         )
 
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim
@@ -148,63 +154,81 @@ class VitModel(MegatronModule):
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
 
-        # cls_token
-        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
-        torch.nn.init.zeros_(self.cls_token)
-
-        # Linear encoder
-        self.linear_encoder = torch.nn.Linear(
-            self.flatten_dim, self.hidden_size
-        )
-
-        # embedding
-        self.position_embeddings = torch.nn.Embedding(
-            self.seq_length, self.hidden_size
-        )
-        init_method_normal(args.init_method_std)(
-            self.position_embeddings.weight
-        )
-        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
-
-        self.position_embeddings._register_load_state_dict_pre_hook(
-            twod_interpolate_position_embeddings_hook
-        )
-
-        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
+        if self.pre_process:
+            # cls_token
+            self.cls_token = torch.nn.Parameter(
+                torch.randn(1, 1, self.hidden_size)
+            )
+            torch.nn.init.zeros_(self.cls_token)
+
+            # Linear encoder
+            self.linear_encoder = torch.nn.Linear(
+                self.flatten_dim, self.hidden_size
+            )
+
+            # embedding
+            self.position_embeddings = torch.nn.Embedding(
+                self.seq_length, self.hidden_size
+            )
+            init_method_normal(args.init_method_std)(
+                self.position_embeddings.weight
+            )
+            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
+
+            self.position_embeddings._register_load_state_dict_pre_hook(
+                twod_interpolate_position_embeddings_hook
+            )
+
+            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
 
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, self.scaled_init_method
+            self.init_method,
+            self.scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process
         )
 
-        # MLP head
-        if not self.finetune:
-            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
-        else:
-            self.class_head = get_linear_layer(
-                self.hidden_size, num_classes, torch.nn.init.zeros_
-            )
-
-    def forward(self, x):
-        x = einops.rearrange(
-            x,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim,
-            p2=self.patch_dim,
-        )
-
-        assert x.dtype == torch.half
-        x = self.linear_encoder(x)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
-
-        x = x + self.position_embeddings(self.position_ids)
-        x = self.embedding_dropout(x)
-        x = self.transformer(x, None)
-
-        if not self.finetune:
-            x = self.mlp_head(x)
-        else:
-            x = self.class_head(x[:, 0, :])
-
-        return x
+        if self.post_process:
+            # MLP head
+            if not self.finetune:
+                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+            else:
+                self.class_head = get_linear_layer(
+                    self.hidden_size, num_classes, torch.nn.init.zeros_
+                )
+
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.transformer.set_input_tensor(input_tensor)
+
+    def forward(self, input):
+        if self.pre_process:
+            rearranged_input = einops.rearrange(
+                input,
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=self.patch_dim,
+                p2=self.patch_dim,
+            )
+
+            assert rearranged_input.dtype == torch.half
+            encoder_output = self.linear_encoder(rearranged_input)
+            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
+
+            token_embeddings = concatenated_tokens + \
+                self.position_embeddings(self.position_ids)
+            hidden_states = self.embedding_dropout(token_embeddings)
+        else:
+            hidden_states = input
+
+        hidden_states = self.transformer(hidden_states, None)
+
+        if self.post_process:
+            if not self.finetune:
+                hidden_states = self.mlp_head(hidden_states)
+            else:
+                hidden_states = self.class_head(hidden_states[:, 0, :])
+
+        return hidden_states
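Note: the new pre_process/post_process flags let each pipeline stage build only the pieces it owns. The first stage gets the linear patch encoder, cls_token and position embeddings, the last stage gets the MLP/classification head, and every stage keeps its own slice of the ParallelTransformer; intermediate stages receive activations via set_input_tensor. A minimal sketch of how the flags map onto stages (the two booleans stand in for mpu.is_pipeline_first_stage() / mpu.is_pipeline_last_stage(); this is not code from the commit):

    # Sketch only: constructing the ViT for one pipeline stage.
    from megatron.model.vit_model import VitModel

    def build_vit_for_stage(num_classes, is_first_stage, is_last_stage):
        # First stage owns the patch embedding, last stage owns the head,
        # every stage owns its transformer layers.
        return VitModel(num_classes=num_classes,
                        pre_process=is_first_stage,
                        post_process=is_last_stage)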
@@ -17,19 +17,22 @@
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
 
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
     print_rank_0("building VIT model ...")
     args = get_args()
 
-    model = VitModel(num_classes=args.num_classes)
+    model = VitModel(num_classes=args.num_classes,
+                     pre_process=pre_process,
+                     post_process=post_process)
     return model
 
 
 def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):
     return images, labels
 
 
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+    logits = output_tensor.contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+
+def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()
 
-    assert input_tensor is None
-
     # Get the batch.
     timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
     timers("batch-generator").stop()
 
     # Forward model. lm_labels
-    logits = model(images).contiguous().float()
-
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+    output_tensor = model(images)
+
+    return output_tensor, partial(loss_func, labels)
 
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
...
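Note: forward_step no longer computes the loss itself. It returns the raw output tensor plus a partially applied loss function, and the pipeline schedule invokes that function only on the stage that actually holds the logits. A simplified sketch of the contract (the real scheduling lives in megatron/schedules.py; is_last_stage stands in for mpu.is_pipeline_last_stage()):

    # Sketch only: how a schedule consumes the (output_tensor, loss_func)
    # pair returned by forward_step.
    def run_forward(forward_step_func, data_iterator, model, is_last_stage):
        output_tensor, loss_func = forward_step_func(data_iterator, model)
        if is_last_stage:
            # Only the last stage has logits, so only it evaluates the loss.
            return loss_func(output_tensor)
        # Earlier stages just pass their activations downstream.
        return output_tensor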
@@ -34,13 +34,14 @@ def classification():
         )
         return train_ds, valid_ds
 
-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
 
         args = get_args()
 
         print_rank_0("building classification model for ImageNet ...")
 
-        return VitModel(num_classes=args.num_classes, finetune=True)
+        return VitModel(num_classes=args.num_classes, finetune=True,
+                        pre_process=pre_process, post_process=post_process)
 
     """Finetune/evaluate."""
 
     finetune(
...
@@ -16,10 +16,14 @@
 """Evaluation utilities."""
 
 import os
+from functools import partial
 
 import torch
 
 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, print_rank_last
 from megatron import mpu
+from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
 from torchvision import datasets, transforms
@@ -56,7 +60,7 @@ def accuracy_func_provider():
         print_rank_0("calculating metrics ...")
         correct, total = calculate_correct_answers(model, dataloader, epoch)
         percent = float(correct) * 100.0 / float(total)
-        print_rank_0(
+        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )
@@ -67,29 +71,61 @@ def accuracy_func_provider():
 
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""
-    model.eval()
+
+    args = get_args()
+    forward_backward_func = get_forward_backward_func()
+    for m in model:
+        m.eval()
+
+    def loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels).float()
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    #defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(images)
+
+        return output_tensor, partial(loss_func, labels)
+
     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0
         correct = 0
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            images, labels = process_batch(batch)
-            logits = model(images).contiguous().float()
-            # Add output predictions.
-            # Compute the correct answers.
-            predicted = torch.argmax(logits, dim=-1)
-            corrects = (predicted == labels).float()
-            # Add to the counters.
-            total += labels.size(0)
-            correct += corrects.sum().item()
-    model.train()
+
+            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
+                                               optimizer=None, timers=None, forward_only=True)
+
+            for loss_dict in loss_dicts:
+                total += loss_dict['total']
+                correct += loss_dict['correct']
+
+    for m in model:
+        m.train()
 
     # Reduce.
-    unreduced = torch.cuda.LongTensor([correct, total])
-    torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
-
-    # Print on screen.
-    correct_ans = unreduced[0].item()
-    total_count = unreduced[1].item()
-    return correct_ans, total_count
+    if mpu.is_pipeline_last_stage():
+        unreduced = torch.cuda.LongTensor([correct, total])
+        torch.distributed.all_reduce(unreduced,
+                                     group=mpu.get_data_parallel_group())
+
+        # Print on screen.
+        correct_ans = unreduced[0].item()
+        total_count = unreduced[1].item()
+        return correct_ans, total_count
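Note: with pipeline parallelism, forward_backward_func returns one loss dict per micro-batch, only the last pipeline stage holds the counters, and the data-parallel all_reduce then sums them across replicas. The reduction itself boils down to the following standalone sketch (assumes torch.distributed is already initialized):

    # Sketch of the counter reduction used above: each data-parallel replica
    # counts locally, then the group all-reduce sums the counts so every
    # replica sees the global correct/total.
    import torch

    def reduce_counts(correct, total, data_parallel_group=None):
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced, group=data_parallel_group)
        return unreduced[0].item(), unreduced[1].item()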
@@ -17,6 +17,7 @@
 
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
@@ -38,10 +39,21 @@ def process_batch(batch):
     return images, labels
 
 
-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()
 
-    assert input_tensor is None
-
     # Get the batch.
     timers("batch generator").start()
@@ -52,16 +64,10 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     images, labels = process_batch(batch_)
     timers("batch generator").stop()
 
     # Forward model.
-    logits = model(images).contiguous().float()
-
-    # Cross-entropy loss.
-    loss = F.cross_entropy(logits, labels)
-
-    # Reduce loss for logging.
-    average_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss, {"lm loss": average_loss[0]}
+    output_tensor = model(images)
 
+    return output_tensor, partial(cross_entropy_loss_func, labels)
 
 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -103,23 +109,28 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     """Traing and validation dataloaders."""
     args = get_args()
 
-    print_rank_0("building train and validation dataloaders ...")
+    print_rank_0('building train and validation dataloaders ...')
 
     # Training dataset.
-    train_dataloader = build_data_loader(
-        train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
-    )
+    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
+                                         args.num_workers, not args.keep_last)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
-    valid_dataloader_ = build_data_loader(
-        valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
-    )
+    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
+                                          args.num_workers, not args.keep_last)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
 
-    return train_dataloader, valid_dataloader
+    # Now that we've built the data loaders, set batch_size arguments
+    # to the actual batch size the model will see for this dataset.
+    # This is necessary so pipeline transfers know what size they are
+    # and the LR schedule, which is based on samples seen, gets set
+    # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
+
+    return train_dataloader, valid_dataloader
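Note: saving args.orig_micro_batch_size / args.orig_global_batch_size lets later code temporarily override the batch-size arguments for a differently sized dataset and then restore the originals, which matters because pipeline transfers and the sample-based LR schedule both read these values. A hedged sketch of that override/restore pattern (the helper names below are illustrative, not Megatron API):

    # Illustrative only: override the batch-size args for a smaller eval set,
    # then restore the originals saved in _build_train_valid_dataloaders.
    def override_batch_sizes(args, micro_batch_size, global_batch_size):
        args.micro_batch_size = micro_batch_size
        args.global_batch_size = global_batch_size

    def restore_batch_sizes(args):
        args.micro_batch_size = args.orig_micro_batch_size
        args.global_batch_size = args.orig_global_batch_size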
 def _train(
     model,
@@ -135,7 +146,8 @@ def _train(
     timers = get_timers()
 
     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()
 
     # Tracking loss.
     losses_dict_sum = {}
@@ -166,12 +178,16 @@ def _train(
                 start_iteration = 0
 
             # Train for one step.
-            losses_dict, skipped_iter = train_step(
+            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                 forward_step, batch, model, optimizer, lr_scheduler
             )
             iteration += 1
 
             # Logging.
+            params_norm = None
+            if args.log_params_norm:
+                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(
                 losses_dict,
                 losses_dict_sum,
@@ -180,6 +196,9 @@ def _train(
                 optimizer.get_loss_scale().item(),
                 report_memory_flag,
                 skipped_iter,
+                grad_norm,
+                params_norm,
+                num_zeros_in_grad
             )
 
             # Autoresume
...
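Note: the extra training_log arguments assume train_step now also returns grad_norm and num_zeros_in_grad, and params_norm is computed on demand via calc_params_l2_norm (whose import is not shown in this hunk). For intuition, a plain single-process parameter L2 norm looks like the sketch below; Megatron's calc_params_l2_norm additionally reduces across model-parallel ranks, so this is not a drop-in replacement.

    # Standalone sketch: L2 norm over all parameters of a list of model
    # chunks, since `model` is a list after the pipeline changes above.
    import torch

    def simple_params_l2_norm(model_chunks):
        squares = [p.detach().float().pow(2).sum()
                   for chunk in model_chunks
                   for p in chunk.parameters()]
        return torch.sqrt(torch.stack(squares).sum()).item()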