Commit ac967fa0 authored by Neel Kant's avatar Neel Kant
Browse files

Merge branch 'master' into ict-merge

parents 7b3baaaa 46a536cc
...@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
if args.num_unique_layers < args.num_layers: if args.num_unique_layers < args.num_layers:
assert args.DDP_impl == 'local', \ assert args.DDP_impl == 'local', \
'torch-DDP does not work with parameters sharing.' 'torch-DDP does not work with parameters sharing.'
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
_print_args(args) _print_args(args)
return args return args
...@@ -300,6 +303,10 @@ def _add_mixed_precision_args(parser): ...@@ -300,6 +303,10 @@ def _add_mixed_precision_args(parser):
help='Window over which to raise/lower dynamic scale.') help='Window over which to raise/lower dynamic scale.')
group.add_argument('--min-scale', type=float, default=1, group.add_argument('--min-scale', type=float, default=1,
help='Minimum loss scale for dynamic loss scale.') help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
return parser return parser
......
...@@ -159,7 +159,7 @@ def get_samples_mapping_(indexed_dataset, ...@@ -159,7 +159,7 @@ def get_samples_mapping_(indexed_dataset,
print_rank_0(' > loading indexed mapping from {}'.format( print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename)) indexmap_filename))
start_time = time.time() start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True) samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time)) time.time() - start_time))
print_rank_0(' total number of samples: {}'.format( print_rank_0(' total number of samples: {}'.format(
......
...@@ -211,13 +211,13 @@ def _build_index_mappings(name, data_prefix, documents, sizes, ...@@ -211,13 +211,13 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time = time.time() start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format( print_rank_0(' > loading doc-idx mapping from {}'.format(
doc_idx_filename)) doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True) doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading sample-idx mapping from {}'.format( print_rank_0(' > loading sample-idx mapping from {}'.format(
sample_idx_filename)) sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True) sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading shuffle-idx mapping from {}'.format( print_rank_0(' > loading shuffle-idx mapping from {}'.format(
shuffle_idx_filename)) shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True) shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time)) time.time() - start_time))
print_rank_0(' total number of samples: {}'.format( print_rank_0(' total number of samples: {}'.format(
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import torch import torch
from megatron import get_args from megatron import get_args
from megatron import mpu
from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm from megatron.model.transformer import LayerNorm
...@@ -80,6 +81,7 @@ class BertModel(MegatronModule): ...@@ -80,6 +81,7 @@ class BertModel(MegatronModule):
super(BertModel, self).__init__() super(BertModel, self).__init__()
args = get_args() args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head self.add_binary_head = add_binary_head
self.parallel_output = parallel_output self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
...@@ -102,7 +104,8 @@ class BertModel(MegatronModule): ...@@ -102,7 +104,8 @@ class BertModel(MegatronModule):
init_method) init_method)
self._binary_head_key = 'binary_head' self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask, tokentype_ids=None): def forward(self, input_ids, attention_mask,
tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask( extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype) attention_mask, next(self.language_model.parameters()).dtype)
...@@ -125,11 +128,21 @@ class BertModel(MegatronModule): ...@@ -125,11 +128,21 @@ class BertModel(MegatronModule):
lm_logits = self.lm_head( lm_logits = self.lm_head(
lm_output, self.language_model.embedding.word_embeddings.weight) lm_output, self.language_model.embedding.word_embeddings.weight)
binary_logits = None
if self.add_binary_head: if self.add_binary_head:
binary_logits = self.binary_head(pooled_output) binary_logits = self.binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits return lm_logits, binary_logits
else:
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss, binary_logits
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import torch import torch
from megatron import get_args from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule from megatron.module import MegatronModule
from .language_model import parallel_lm_logits from .language_model import parallel_lm_logits
...@@ -39,6 +40,7 @@ class GPT2Model(MegatronModule): ...@@ -39,6 +40,7 @@ class GPT2Model(MegatronModule):
args = get_args() args = get_args()
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=gpt2_attention_mask_func, attention_mask_func=gpt2_attention_mask_func,
...@@ -48,7 +50,7 @@ class GPT2Model(MegatronModule): ...@@ -48,7 +50,7 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
...@@ -75,7 +77,16 @@ class GPT2Model(MegatronModule): ...@@ -75,7 +77,16 @@ class GPT2Model(MegatronModule):
if get_key_value: if get_key_value:
output = [output, presents] output = [output, presents]
return output if labels is None:
return output
else:
if self.fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = mpu.vocab_parallel_cross_entropy(output, labels)
else:
loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
return loss
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
......
...@@ -67,6 +67,7 @@ def get_batch(data_iterator): ...@@ -67,6 +67,7 @@ def get_batch(data_iterator):
def forward_step(data_iterator, model): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
...@@ -75,15 +76,15 @@ def forward_step(data_iterator, model): ...@@ -75,15 +76,15 @@ def forward_step(data_iterator, model):
= get_batch(data_iterator) = get_batch(data_iterator)
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model. lm_labels
lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types) lm_loss_, sop_logits = model(tokens, padding_mask,
tokentype_ids=types,
lm_labels=lm_labels)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(), sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1).contiguous(), sentence_order.view(-1),
ignore_index=-1) ignore_index=-1)
lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
lm_labels.contiguous())
lm_loss = torch.sum( lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
......
...@@ -71,6 +71,7 @@ def get_batch(data_iterator): ...@@ -71,6 +71,7 @@ def get_batch(data_iterator):
def forward_step(data_iterator, model): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
...@@ -80,9 +81,8 @@ def forward_step(data_iterator, model): ...@@ -80,9 +81,8 @@ def forward_step(data_iterator, model):
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
output = model(tokens, position_ids, attention_mask) losses = model(tokens, position_ids, attention_mask, labels=labels)
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
labels)
loss_mask = loss_mask.view(-1) loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment