Commit 01fc0833 authored by Jared Casper

Merge branch 'vit_pipeline_fixes' into 'main'

vit pipeline fixes

See merge request ADLR/megatron-lm!276
parents 217f54b3 ccae9dbd
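The recurring change in the hunks below is that forward_step no longer computes the loss itself: it returns the model output together with a loss function bound to the labels via functools.partial, and the pipeline schedule invokes that closure only on the stage that holds the output. A simplified, self-contained sketch of that contract (names follow the diff, but this is illustrative rather than the Megatron code verbatim; data-parallel loss averaging is omitted):

```python
from functools import partial

import torch
import torch.nn.functional as F


def loss_func(labels, output_tensor):
    # Runs only where the output lives (the last pipeline stage).
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)
    accuracy = (torch.argmax(logits, -1) == labels).float().mean()
    return loss, {"loss": loss.detach(), "accuracy": accuracy}


def forward_step(data_iterator, model):
    images, labels = next(data_iterator)
    output_tensor = model(images)
    # Bind the labels now; the schedule decides whether to call the closure.
    return output_tensor, partial(loss_func, labels)
```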
@@ -60,8 +60,8 @@ def check_checkpoint_args(checkpoint_args):
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
if args.vocab_file:
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
@@ -50,11 +50,11 @@ class VitMlpHead(MegatronModule):
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
x = hidden_states[:, sequence_index, :]
x = self.dense_in(x)
x = torch.tanh(x)
x = self.dense_out(x)
return x
hidden_state = hidden_states[:, sequence_index, :]
dense_in_result = self.dense_in(hidden_state)
tanh_result = torch.tanh(dense_in_result)
dense_out_result = self.dense_out(tanh_result)
return dense_out_result
def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
class VitModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__()
def __init__(self,
num_classes,
finetune=False,
pre_process=True,
post_process=True):
super(VitModel, self).__init__(share_word_embeddings=False)
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ class VitModel(MegatronModule):
args.init_method_std, args.num_layers
)
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
@@ -148,63 +154,81 @@ class VitModel(MegatronModule):
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
# cls_token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
torch.nn.init.zeros_(self.cls_token)
if self.pre_process:
# cls_token
self.cls_token = torch.nn.Parameter(
torch.randn(1, 1, self.hidden_size)
)
torch.nn.init.zeros_(self.cls_token)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method, self.scaled_init_method
self.init_method,
self.scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process
)
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
if self.post_process:
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.transformer.set_input_tensor(input_tensor)
def forward(self, input):
if self.pre_process:
rearranged_input = einops.rearrange(
input,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
def forward(self, x):
x = einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
assert rearranged_input.dtype == torch.half
encoder_output = self.linear_encoder(rearranged_input)
cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
assert x.dtype == torch.half
x = self.linear_encoder(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
token_embeddings = concatenated_tokens + \
self.position_embeddings(self.position_ids)
hidden_states = self.embedding_dropout(token_embeddings)
else:
hidden_states = input
x = x + self.position_embeddings(self.position_ids)
x = self.embedding_dropout(x)
x = self.transformer(x, None)
hidden_states = self.transformer(hidden_states, None)
if not self.finetune:
x = self.mlp_head(x)
else:
x = self.class_head(x[:, 0, :])
if self.post_process:
if not self.finetune:
hidden_states = self.mlp_head(hidden_states)
else:
hidden_states = self.class_head(hidden_states[:, 0, :])
return x
return hidden_states
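The pre_process/post_process flags added above let VitModel serve as one stage of a pipeline: the first stage embeds the image patches, later stages receive the previous stage's activations through set_input_tensor, and only the last stage applies the head. A toy sketch of that split, with illustrative class and parameter names (this is not the Megatron API, just the shape of the idea):

```python
import torch


class ToyStage(torch.nn.Module):
    """Toy illustration of the pre_process/post_process split (not Megatron code)."""

    def __init__(self, pre_process, post_process, in_dim=4, hidden=8, num_classes=3):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        # Only the first stage owns the input embedding ...
        self.embed = torch.nn.Linear(in_dim, hidden) if pre_process else None
        self.body = torch.nn.Linear(hidden, hidden)
        # ... and only the last stage owns the classification head.
        self.head = torch.nn.Linear(hidden, num_classes) if post_process else None
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # The pipeline schedule calls this with the previous stage's activations.
        self.input_tensor = input_tensor

    def forward(self, inp):
        x = self.embed(inp) if self.pre_process else self.input_tensor
        x = self.body(x)
        return self.head(x) if self.post_process else x


# First stage consumes the raw batch; the last stage consumes hidden states.
first = ToyStage(pre_process=True, post_process=False)
last = ToyStage(pre_process=False, post_process=True)
hidden = first(torch.randn(2, 4))
last.set_input_tensor(hidden)
logits = last(None)  # shape: [2, 3]
```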
@@ -17,19 +17,22 @@
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
args = get_args()
model = VitModel(num_classes=args.num_classes)
model = VitModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
return model
def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):
return images, labels
def forward_step(data_iterator, model, input_tensor):
def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels)
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
accuracy = torch.mean(correct)
averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
assert input_tensor is None
# Get the batch.
timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
timers("batch-generator").stop()
# Forward model. lm_labels
logits = model(images).contiguous().float()
loss = F.cross_entropy(logits, labels)
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
accuracy = torch.mean(correct)
averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
@@ -34,13 +34,14 @@ def classification():
)
return train_ds, valid_ds
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True)
return VitModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
"""Finetune/evaluate."""
finetune(
@@ -16,10 +16,14 @@
"""Evaluation utilities."""
import os
from functools import partial
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
@@ -56,7 +60,7 @@ def accuracy_func_provider():
print_rank_0("calculating metrics ...")
correct, total = calculate_correct_answers(model, dataloader, epoch)
percent = float(correct) * 100.0 / float(total)
print_rank_0(
print_rank_last(
" >> |epoch: {}| overall: correct / total = {} / {} = "
"{:.4f} %".format(epoch, correct, total, percent)
)
@@ -67,29 +71,61 @@ def accuracy_func_provider():
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
model.eval()
args = get_args()
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
def loss_func(labels, output_tensor):
logits = output_tensor
loss_dict = {}
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels).float()
# Add to the counters.
loss_dict['total'] = labels.size(0)
loss_dict['correct'] = corrects.sum().item()
return 0, loss_dict
#defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
with torch.no_grad():
# For all the batches in the dataset.
total = 0
correct = 0
for _, batch in enumerate(dataloader):
# Run the model forward.
images, labels = process_batch(batch)
logits = model(images).contiguous().float()
# Add output predictions.
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels).float()
# Add to the counters.
total += labels.size(0)
correct += corrects.sum().item()
model.train()
loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
optimizer=None, timers=None, forward_only=True)
for loss_dict in loss_dicts:
total += loss_dict['total']
correct += loss_dict['correct']
for m in model:
m.train()
# Reduce.
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
if mpu.is_pipeline_last_stage():
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced,
group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
return correct_ans, total_count
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
return correct_ans, total_count
@@ -17,6 +17,7 @@
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
@@ -38,10 +39,21 @@ def process_batch(batch):
return images, labels
def _cross_entropy_forward_step(batch, model, input_tensor):
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss = F.cross_entropy(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
assert input_tensor is None
# Get the batch.
timers("batch generator").start()
@@ -52,16 +64,10 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
logits = model(images).contiguous().float()
# Cross-entropy loss.
loss = F.cross_entropy(logits, labels)
# Reduce loss for logging.
average_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": average_loss[0]}
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -103,23 +109,28 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0("building train and validation dataloaders ...")
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(
train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
)
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(
valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
)
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
return train_dataloader, valid_dataloader
# Now that we've built the data loaders, set batch_size arguments
# to the actual batch size the model will see for this dataset.
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size
return train_dataloader, valid_dataloader
def _train(
model,
@@ -135,7 +146,8 @@ def _train(
timers = get_timers()
# Turn on training mode which enables dropout.
model.train()
for m in model:
m.train()
# Tracking loss.
losses_dict_sum = {}
@@ -166,12 +178,16 @@
start_iteration = 0
# Train for one step.
losses_dict, skipped_iter = train_step(
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
forward_step, batch, model, optimizer, lr_scheduler
)
iteration += 1
# Logging.
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(
losses_dict,
losses_dict_sum,
@@ -180,6 +196,9 @@
optimizer.get_loss_scale().item(),
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad
)
# Autoresume