"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7d4a257c7f2a8fa85fdc1532d19f7dce92e5a3d9"
Commit b1efc33d authored by Neel Kant

Modify pretrain_bert_ict.py to work with ICTBertModel

parent 371d2ea9
@@ -78,7 +78,7 @@ def broadcast_data(keys, data, datatype):
     members of the same model parallel group.
 
     Arguments:
-        keys: list of keys in the data disctionary to be broadcasted
+        keys: list of keys in the data dictionary to be broadcasted
         data: data dictionary of string keys and cpu tensor values.
         datatype: torch data type of all tensors in data associated
                   with keys.
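The docstring fixed above describes the pattern that the rewritten `get_batch` below relies on: only the rank that owns the data iterator passes a real dictionary, every other rank in the model-parallel group passes `None`, and all ranks receive the same tensors back. A schematic sketch of that call pattern, assuming an initialized Megatron model-parallel group (keys and shapes here are illustrative, not taken from the commit):

```python
import torch
from megatron import mpu

keys = ['input_text', 'context_text']
datatype = torch.int64

# Only the first model-parallel rank holds the CPU batch; the others pass None.
if mpu.get_model_parallel_rank() == 0:
    data = {key: torch.randint(0, 30000, (4, 128), dtype=datatype)
            for key in keys}
else:
    data = None

# After the broadcast, every rank holds an identical {key: tensor} dictionary.
data_b = mpu.broadcast_data(keys, data, datatype)
```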
@@ -20,7 +20,7 @@ import torch.nn.functional as F
 
 from configure_data import configure_data
 from megatron import mpu
-from megatron.model import BertModel
+from megatron.model import ICTBertModel
 from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
@@ -30,9 +30,9 @@ from megatron.training import run
 def model_provider(args):
     """Build the model."""
 
-    print_rank_0('building BERT model ...')
+    print_rank_0('building BERT models ...')
 
-    model = BertModel(
+    model = ICTBertModel(
         num_layers=args.num_layers,
         vocab_size=args.vocab_size,
         hidden_size=args.hidden_size,
@@ -42,8 +42,8 @@ def model_provider(args):
         output_dropout_prob=args.hidden_dropout,
         max_sequence_length=args.max_position_embeddings,
         checkpoint_activations=args.checkpoint_activations,
+        ict_head_size=128,
         checkpoint_num_layers=args.checkpoint_num_layers,
-        add_binary_head=True,
         layernorm_epsilon=args.layernorm_epsilon,
         num_tokentypes=args.tokentype_size,
         parallel_output=True,
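The argument changes above are the core of the switch: the next-sentence-prediction head (`add_binary_head=True`) is dropped and a 128-dimensional ICT retrieval head (`ict_head_size=128`) is added. As a rough mental model only, not the `ICTBertModel` implementation, such a head projects the pooled query and context encodings into a shared space and scores every query against every context in the batch (the class below is hypothetical):

```python
import torch
import torch.nn as nn

class TwoTowerRetrievalHead(nn.Module):
    """Toy two-tower head: project pooled query/context encodings to a shared
    space and score each query against every context in the batch."""

    def __init__(self, hidden_size, ict_head_size=128):
        super().__init__()
        self.query_proj = nn.Linear(hidden_size, ict_head_size)
        self.context_proj = nn.Linear(hidden_size, ict_head_size)

    def forward(self, query_repr, context_repr):
        # query_repr, context_repr: [batch, hidden_size] pooled encodings.
        q = self.query_proj(query_repr)       # [batch, ict_head_size]
        c = self.context_proj(context_repr)   # [batch, ict_head_size]
        return q @ c.t()                      # [batch, batch] retrieval scores
```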
@@ -56,27 +56,30 @@ def model_provider(args):
 def get_batch(data_iterator, timers):
 
     # Items and their type.
-    keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
+    keys = ['input_text', 'input_types', 'input_pad_mask',
+            'context_text', 'context_types', 'context_pad_mask']
     datatype = torch.int64
 
     # Broadcast data.
     timers('data loader').start()
-    if data_iterator is not None:
-        data = next(data_iterator)
-    else:
+    if data_iterator is None:
         data = None
+    else:
+        data = next(data_iterator)
     timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)
 
     # Unpack.
-    tokens = data_b['text'].long()
-    types = data_b['types'].long()
-    next_sentence = data_b['is_random'].long()
-    loss_mask = data_b['mask'].float()
-    lm_labels = data_b['mask_labels'].long()
-    padding_mask = data_b['pad_mask'].long()
+    input_tokens = data_b['input_text'].long()
+    input_types = data_b['input_types'].long()
+    input_pad_mask = data_b['input_pad_mask'].long()
+    context_tokens = data_b['context_text'].long()
+    context_types = data_b['context_types'].long()
+    context_pad_mask = data_b['context_pad_mask'].long()
 
-    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask
+    return input_tokens, input_types, input_pad_mask,\
+           context_tokens, context_types, context_pad_mask
 
 
 def forward_step(data_iterator, model, args, timers):
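The keys unpacked by the new `get_batch` imply that each sample now pairs a query sentence (`input_*`) with a candidate context block (`context_*`), each with its own token-type ids and padding mask. A toy example of the kind of dictionary the data iterator would have to yield for this code to work (keys come from the diff above; shapes and values are purely illustrative, and the padding masks assume 1 marks a padded position, since `forward_step` passes `1 - pad_mask` as the attention mask):

```python
import torch

batch_size, input_len, context_len = 4, 64, 256

sample = {
    'input_text':       torch.randint(0, 30000, (batch_size, input_len)),
    'input_types':      torch.zeros(batch_size, input_len, dtype=torch.int64),
    'input_pad_mask':   torch.zeros(batch_size, input_len, dtype=torch.int64),
    'context_text':     torch.randint(0, 30000, (batch_size, context_len)),
    'context_types':    torch.zeros(batch_size, context_len, dtype=torch.int64),
    'context_pad_mask': torch.zeros(batch_size, context_len, dtype=torch.int64),
}
```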
@@ -84,27 +87,20 @@ def forward_step(data_iterator, model, args, timers):
 
     # Get the batch.
     timers('batch generator').start()
-    tokens, types, next_sentence, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator, timers)
+    input_tokens, input_types, input_pad_mask,\
+        context_tokens, context_types, context_pad_mask = get_batch(data_iterator, timers)
     timers('batch generator').stop()
 
     # Forward model.
-    lm_logits, nsp_logits = model(tokens, 1-padding_mask, tokentype_ids=types)
-
-    nsp_loss = F.cross_entropy(nsp_logits.view(-1, 2).contiguous().float(),
-                               next_sentence.view(-1).contiguous(),
-                               ignore_index=-1)
-
-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
-    lm_loss = torch.sum(
-        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-
-    loss = lm_loss + nsp_loss
+    retrieval_scores = model(input_tokens, 1 - input_pad_mask, input_types,
+                             context_tokens, 1 - context_pad_mask, context_types)
+
+    softmaxed = F.softmax(retrieval_scores, dim=0).float()
+    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]))
 
-    reduced_losses = reduce_losses([lm_loss, nsp_loss])
+    reduced_losses = reduce_losses([retrieval_loss])
 
-    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}
+    return retrieval_loss, {'retrieval loss': reduced_losses[0]}
 
 
 def get_train_val_test_data(args):
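The new loss treats retrieval as classification over the batch: for query i the positive context is context i, so the targets are simply 0..batch_size-1. A self-contained sketch of that in-batch cross-entropy, written in the more common form where raw scores go straight to `F.cross_entropy` (which applies `log_softmax` internally); it is not the exact lines from the diff:

```python
import torch
import torch.nn.functional as F

batch_size, head_size = 4, 128

# Projected query and context embeddings from the two towers (illustrative).
query_emb = torch.randn(batch_size, head_size)
context_emb = torch.randn(batch_size, head_size)

# retrieval_scores[i, j] = similarity between query i and context j.
retrieval_scores = query_emb @ context_emb.t()

# The matching context for query i is context i.
labels = torch.arange(batch_size)
retrieval_loss = F.cross_entropy(retrieval_scores, labels)
```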
@@ -152,5 +148,5 @@ def get_train_val_test_data(args):
 
 
 if __name__ == "__main__":
-    run('Pretrain BERT model', get_train_val_test_data,
+    run('Pretrain ICT BERT model', get_train_val_test_data,
         model_provider, forward_step)