Commit caa9dca5 authored by Jared Casper, committed by Deepak Narayanan

Add pipelining to GLUE and RACE tasks

parent 3574b8e6
......@@ -493,7 +493,7 @@ Further command line arguments are described in the source file [`main.py`](./ta
## BERT Task Evaluation
<a id="race-evaluation"></a>
### RACE Evaluation
The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directories contain the RACE dataset as separate `.txt` files.
The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directories contain the RACE dataset as separate `.txt` files. Note that for RACE, the batch size is the number of RACE queries to evaluate. Since each RACE query has four samples, the effective batch size passed through the model will be four times the batch size specified on the command line.
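Before the script, a rough illustration of that arithmetic (a minimal sketch; the variable names here are ours, not flags of the script):
<pre>
# RACE has four choices per query, so each query expands to four samples.
NUM_CHOICES = 4
cli_batch_size = 8                                   # batch size given on the command line
effective_batch_size = cli_batch_size * NUM_CHOICES  # 32 sequences per forward pass
</pre>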
<pre>
TRAIN_DATA="data/RACE/train/middle"
......
......@@ -42,11 +42,14 @@ def print_rank_0(message):
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
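A hedged usage sketch of the new helpers (illustrative message; assumes `torch.distributed` is initialized): with pipeline parallelism the loss and logits live on the last stage, so logging moves from rank 0 to the last rank.

# Illustrative only: on a 4-process job this prints on global rank 3.
print_rank_last('validation loss: 1.23')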
......@@ -119,8 +119,9 @@ class AnnealingLR(object):
return cls_value
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, 'AnnealingLR: class input value' \
'and checkpoint values for {} do not match'.format(name)
assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint ' \
f'value {sd_value} for {name} do not match'
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name))
return sd_value
......
......@@ -18,18 +18,19 @@
import torch
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron.module import PipelinedMegatronModule
class Classification(MegatronModule):
class ClassificationBase(PipelinedMegatronModule):
def __init__(self, num_classes, num_tokentypes=2):
super(Classification, self).__init__()
super(ClassificationBase, self).__init__(share_word_embeddings=False)
args = get_args()
self.num_classes = num_classes
......@@ -50,24 +51,30 @@ class Classification(MegatronModule):
init_method)
self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
def forward(self, model_input, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
kwargs = {}
if mpu.is_pipeline_first_stage():
input_ids = model_input
position_ids = bert_position_ids(input_ids)
# Output.
classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output)
args = [input_ids, position_ids, extended_attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage():
_, pooled_output = lm_output
classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output)
# Reshape back to separate choices.
classification_logits = classification_logits.view(-1, self.num_classes)
# Reshape back to separate choices.
classification_logits = classification_logits.view(-1, self.num_classes)
return classification_logits
return classification_logits
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -95,3 +102,55 @@ class Classification(MegatronModule):
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
class Classification(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(Classification, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(Classification, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationFirstStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationFirstStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(ClassificationFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationIntermediateStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationIntermediateStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationIntermediateStage, self).forward(
hidden_state,
attention_mask)
class ClassificationLastStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationLastStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationLastStage, self).forward(
hidden_state,
attention_mask)
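The stage-specific subclasses above pin down the signature each pipeline rank actually sees; a sketch of the resulting call pattern (placeholder tensor names, mirroring the evaluation loop later in this commit):

# First stage embeds token ids; later stages consume the hidden state
# received from the previous stage over the pipeline.
if mpu.is_pipeline_first_stage():
    output = model(tokens, attention_mask, tokentype_ids=types)  # hidden state out
elif mpu.is_pipeline_last_stage():
    logits = model(input_tensor, attention_mask)                 # classification logits
else:
    output = model(input_tensor, attention_mask)                 # hidden state to send on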
......@@ -18,18 +18,19 @@
import torch
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron.module import PipelinedMegatronModule
class MultipleChoice(MegatronModule):
class MultipleChoiceBase(PipelinedMegatronModule):
def __init__(self, num_tokentypes=2):
super(MultipleChoice, self).__init__()
super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
args = get_args()
init_method = init_method_normal(args.init_method_std)
......@@ -48,38 +49,44 @@ class MultipleChoice(MegatronModule):
init_method)
self._multichoice_head_key = 'multichoice_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
def forward(self, model_input, attention_mask, tokentype_ids=None):
# [batch, choices, sequence] --> [batch * choices, sequence] -->
# transformer --> [batch, choices] --> softmax
# Ensure the shape is [batch-size, choices, sequence]
assert len(input_ids.shape) == 3
assert len(attention_mask.shape) == 3
assert len(tokentype_ids.shape) == 3
num_choices = attention_mask.shape[1]
# Reshape and treat choice dimension the same as batch.
num_choices = input_ids.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1))
attention_mask = attention_mask.view(-1, attention_mask.size(-1))
tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
multichoice_output = self.multichoice_dropout(pooled_output)
multichoice_logits = self.multichoice_head(multichoice_output)
kwargs = {}
if mpu.is_pipeline_first_stage():
input_ids = model_input
# Reshape input_ids and tokentype_ids the same way as attention_mask.
assert len(input_ids.shape) == 3
assert len(tokentype_ids.shape) == 3
input_ids = input_ids.view(-1, input_ids.size(-1))
tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
position_ids = bert_position_ids(input_ids)
args = [input_ids, position_ids, extended_attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage():
_, pooled_output = lm_output
multichoice_output = self.multichoice_dropout(pooled_output)
multichoice_logits = self.multichoice_head(multichoice_output)
# Reshape back to separate choices.
multichoice_logits = multichoice_logits.view(-1, num_choices)
# Reshape back to separate choices.
multichoice_logits = multichoice_logits.view(-1, num_choices)
return multichoice_logits
return multichoice_logits
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -107,3 +114,54 @@ class MultipleChoice(MegatronModule):
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._multichoice_head_key))
class MultipleChoice(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoice, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoice, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceFirstStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoiceFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceIntermediateStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceIntermediateStage, self).forward(
hidden_state,
attention_mask)
class MultipleChoiceLastStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceLastStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceLastStage, self).forward(
hidden_state,
attention_mask)
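To make the choice-flattening in `MultipleChoiceBase.forward` concrete, a self-contained shape sketch (random data, illustrative sizes):

import torch

batch, choices, seq = 2, 4, 8
input_ids = torch.randint(0, 100, (batch, choices, seq))
flat = input_ids.view(-1, input_ids.size(-1))  # [8, 8]: choices folded into batch
scores = torch.randn(flat.size(0), 1)          # stand-in for the multichoice head
logits = scores.view(-1, choices)              # [2, 4]: one score per choice per query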
......@@ -37,19 +37,27 @@ class MegatronModule(torch.nn.Module):
class PipelinedMegatronModule(MegatronModule):
"""Pipelining specific extensions of MegatronModule."""
def __init__(self):
def __init__(self, share_word_embeddings=True):
super(PipelinedMegatronModule, self).__init__()
args = get_args()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage():
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage():
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last stage, '
'but share_word_embeddings is false')
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
# Parameters are shared between the word embeddings layer, and the heads at
# the end of the model. In a pipelined setup with more than one stage, the
# initial embedding layer and the head are on different workers, so we do
......
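The truncated comment above describes tying the input embedding to the output head across pipeline stages; a minimal sketch of the gradient synchronization that keeps the two weight copies identical (mirroring the training-loop change below; the embedding process group is assumed to contain the first and last stages):

# Both stages accumulate embedding gradients locally; summing them over
# the embedding group makes the two copies take identical optimizer steps.
weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(weight.grad, group=mpu.get_embedding_group())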
......@@ -575,9 +575,10 @@ def train_step(forward_step_func, data_iterator,
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
unwrapped_model = unwrapped_model.module
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update master gradients.
......@@ -847,7 +848,6 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
return total_loss_dict
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, verbose=False):
......
......@@ -21,8 +21,9 @@ import time
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.training import communicate
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
......@@ -42,7 +43,7 @@ def accuracy_func_provider(single_dataset_provider):
dataloaders.append((dataset.dataset_name, dataloader))
def metrics_func(model, epoch, output_predictions=False):
print_rank_0('calculating metrics ...')
print_rank_last('calculating metrics ...')
correct = 0
total = 0
if output_predictions:
......@@ -60,25 +61,26 @@ def accuracy_func_provider(single_dataset_provider):
names += '_' + name
correct += correct_ans
total += total_count
percent = float(correct) * 100.0 / float(total)
print_rank_0(' >> |epoch: {}| overall: correct / total = {} / {} = '
'{:.4f} %'.format(epoch, correct, total, percent))
if is_last_rank():
percent = float(correct) * 100.0 / float(total)
print(' >> |epoch: {}| overall: correct / total = {} / {} = '
'{:.4f} %'.format(epoch, correct, total, percent))
if output_predictions and torch.distributed.get_rank() == 0:
if output_predictions and is_last_rank():
assert args.load is not None
filename = os.path.join(args.load, names + '.pt')
torch.save(named_predictions, filename)
return metrics_func
def calculate_correct_answers(name, model, dataloader,
epoch, output_predictions):
"""Calculate correct over total answers and return prediction if the
`output_predictions` is true."""
args = get_args()
start_time = time.time()
model.eval()
saved_batch_size = args.batch_size
with torch.no_grad():
# For all the batches in the dataset.
total = 0
......@@ -92,36 +94,79 @@ def calculate_correct_answers(name, model, dataloader,
for _, batch in enumerate(dataloader):
# Run the model forward.
tokens, types, labels_, attention_mask = process_batch(batch)
logits = model(tokens, attention_mask, types)
# Add output predictions.
if output_predictions:
softmaxes.extend(torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist())
labels.extend(labels_.data.cpu().numpy().tolist())
ids.extend(batch['uid'].cpu().numpy().tolist())
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels_)
# Add to the counters.
total += labels_.size(0)
correct += corrects.sum().item()
# For evaluation only mode we use drop_last = False to get all the
# samples, which means we might not have a full batch, so we
# adjust batch_size here to actual batch size of data
actual_batch_size = len(labels_)
# ... applying sample_multiplier if necessary
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
actual_batch_size *= ds.sample_multiplier
args.batch_size = actual_batch_size
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
# Add output predictions.
if output_predictions:
softmaxes.extend(torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist())
labels.extend(labels_.data.cpu().numpy().tolist())
ids.extend(batch['uid'].cpu().numpy().tolist())
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels_)
# Add to the counters.
total += labels_.size(0)
correct += corrects.sum().item()
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
model.train()
args.batch_size = saved_batch_size
# Reduce.
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced,
group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
percent = float(correct_ans) * 100.0 / float(total_count)
elapsed_time = time.time() - start_time
print_rank_0(' > |epoch: {}| metrics for {}: correct / total '
'= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
epoch, name, correct_ans, total_count,
percent, elapsed_time))
if mpu.is_pipeline_last_stage():
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced,
group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
percent = float(correct_ans) * 100.0 / float(total_count)
elapsed_time = time.time() - start_time
print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
'= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
epoch, name, correct_ans, total_count,
percent, elapsed_time))
if output_predictions:
return correct_ans, total_count, (softmaxes, labels, ids)
return correct_ans, total_count
if output_predictions:
return correct_ans, total_count, (softmaxes, labels, ids)
return correct_ans, total_count
return 0, 0, ()
return 0, 0
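The `communicate()` calls above move activations between adjacent pipeline stages; a rough point-to-point sketch of the same idea in plain `torch.distributed` (the real helper in `megatron.training` also negotiates tensor shapes and dtypes):

import torch

def recv_forward(shape, prev_rank, dtype=torch.float16):
    # Blocking receive of the previous stage's activations.
    buf = torch.empty(shape, dtype=dtype, device='cuda')
    torch.distributed.recv(buf, src=prev_rank)
    return buf

def send_forward(tensor, next_rank):
    # Hand this stage's activations to the next stage.
    torch.distributed.send(tensor, dst=next_rank)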
......@@ -45,7 +45,7 @@ def process_batch(batch):
return tokens, types, labels, attention_mask
def _cross_entropy_forward_step(batch, model):
def _cross_entropy_forward_step(batch, model, input_tensor):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
......@@ -59,16 +59,25 @@ def _cross_entropy_forward_step(batch, model):
timers('batch generator').stop()
# Forward model.
logits = model(tokens, attention_mask, types)
if mpu.is_pipeline_first_stage():
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
# Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels)
# Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
return loss, {'lm loss': averaged_loss[0]}
return output_tensor
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
......@@ -120,6 +129,11 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set args.batch_size to
# the actual batch size the model will see for this dataset
if hasattr(train_dataset, 'sample_multiplier'):
args.batch_size *= train_dataset.sample_multiplier
return train_dataloader, valid_dataloader
......@@ -211,6 +225,8 @@ def finetune(train_valid_datasets_provider, model_provider,
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset)
else:
args.train_iters = 0
timers('train/valid/test dataset/dataloder').stop()
# Build callback function.
......@@ -255,5 +271,4 @@ def finetune(train_valid_datasets_provider, model_provider,
if end_of_epoch_callback is not None:
print_rank_0('evaluation only mode, setting epoch to -1')
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0('done :-)')
......@@ -18,7 +18,8 @@
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.classification import Classification
from megatron import mpu
from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
......@@ -44,8 +45,21 @@ def glue_classification(num_classes, Dataset,
print_rank_0('building classification model for {} ...'.format(
args.task))
return Classification(num_classes=num_classes, num_tokentypes=2)
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = ClassificationFirstStage(
num_classes=num_classes, num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = ClassificationLastStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = ClassificationIntermediateStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = Classification(num_classes=num_classes, num_tokentypes=2)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
......
......@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
self.sample_multiplier = NUM_CHOICES
def __len__(self):
return len(self.samples)
......
......@@ -18,7 +18,8 @@
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.multiple_choice import MultipleChoice
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
......@@ -41,8 +42,18 @@ def model_provider():
"""Build the model."""
print_rank_0('building multichoice model for RACE ...')
return MultipleChoice(num_tokentypes=2)
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = MultipleChoiceFirstStage(num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = MultipleChoiceLastStage(num_tokentypes=2)
else:
model = MultipleChoiceIntermediateStage(num_tokentypes=2)
else:
model = MultipleChoice(num_tokentypes=2)
return model
def metrics_func_provider():
......