Commit 78022005 authored by mohammad

added fp16 lm cross entropy to bert

parent 22e3c7e6
@@ -18,6 +18,7 @@
 import torch
 
 from megatron import get_args
+from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
@@ -138,7 +139,8 @@ class BertModel(MegatronModule):
                 init_method)
             self._binary_head_key = 'binary_head'
 
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
+    def forward(self, input_ids, attention_mask,
+                tokentype_ids=None, lm_labels=None):
 
         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)
@@ -161,11 +163,16 @@ class BertModel(MegatronModule):
         lm_logits = self.lm_head(
             lm_output, self.language_model.embedding.word_embeddings.weight)
 
+        binary_logits = None
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
-            return lm_logits, binary_logits
 
-        return lm_logits, None
+        if lm_labels is None:
+            return lm_logits, binary_logits
+        else:
+            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            return lm_loss, binary_logits
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
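The point of the new lm_labels path above: when labels are passed into BertModel.forward, mpu.vocab_parallel_cross_entropy is applied to the half-precision logits directly, so the caller never materializes an fp32 copy of the vocabulary-sized logits tensor. A rough sketch of the memory involved (the batch, sequence, and vocabulary sizes below are illustrative assumptions, not values from this commit):

# Back-of-the-envelope memory sketch for the fp16 LM cross-entropy path.
# Sizes are illustrative assumptions, not taken from this commit.
batch, seq, vocab = 32, 512, 30522

fp16_logits_bytes = batch * seq * vocab * 2   # logits already held in half precision
fp32_copy_bytes = batch * seq * vocab * 4     # extra copy created by .float() in the fp32 path

print(f"fp16 logits: {fp16_logits_bytes / 2**20:.0f} MiB")
print(f"extra fp32 copy avoided: {fp32_copy_bytes / 2**20:.0f} MiB")

Computing the loss on fp16 logits trades some numerical headroom for that saving, which is presumably why the behavior is opt-in (see the args.fp16_lm_cross_entropy check in the next file) rather than the default.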
@@ -68,6 +68,7 @@ def get_batch(data_iterator):
 
 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()
 
     # Get the batch.
@@ -76,15 +77,19 @@ def forward_step(data_iterator, model):
         = get_batch(data_iterator)
     timers('batch generator').stop()
 
-    # Forward model.
-    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+    # Forward model. lm_labels
+    if args.fp16_lm_cross_entropy:
+        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
+                                     lm_labels=lm_labels)
+    else:
+        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+        lm_loss_ = mpu.vocab_parallel_cross_entropy(
+            lm_logits.contiguous().float(), lm_labels.contiguous())
 
     sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                                sentence_order.view(-1).contiguous(),
                                ignore_index=-1)
 
-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
-
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
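The branch above keys off args.fp16_lm_cross_entropy. The argument registration itself is not part of the hunks shown here; the sketch below is a guess at what it looks like in the argument parser (likely megatron/arguments.py), with the flag name inferred from the attribute name, so treat both as assumptions:

# Hypothetical sketch of the flag this branch relies on; the real registration
# lives elsewhere (not shown in this diff) and may differ.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fp16-lm-cross-entropy', action='store_true',
                    help='Compute the unreduced LM cross entropy on fp16 '
                         'logits inside the model instead of casting the '
                         'logits to fp32 first.')

args = parser.parse_args([])
print(args.fp16_lm_cross_entropy)  # False unless the flag is passed

Either way, lm_loss_ keeps the same per-token shape, so the masked reduction into lm_loss at the end of the hunk is unchanged.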