Commit acfe848e authored by mohammad

added fp16 cross entropy loss option for gpt2

parent 2ede8235
@@ -294,6 +294,10 @@ def _add_mixed_precision_args(parser):
                        help='Window over which to raise/lower dynamic scale.')
     group.add_argument('--min-scale', type=float, default=1,
                        help='Minimum loss scale for dynamic loss scale.')
+    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the cross entropy unreduced loss calculation '
+                            'for lm head to fp16.')
+
     return parser
......
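For reference, a minimal standalone sketch (plain argparse, not Megatron's own argument plumbing) of how a store_true flag such as --fp16-lm-cross-entropy surfaces as a boolean attribute that forward_step can later read via get_args():

# Minimal sketch, assuming nothing beyond the standard library: the new
# store_true flag defaults to False and flips to True only when passed.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('mixed precision')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation '
                        'for lm head to fp16.')

args = parser.parse_args([])
assert args.fp16_lm_cross_entropy is False            # flag omitted
args = parser.parse_args(['--fp16-lm-cross-entropy'])
assert args.fp16_lm_cross_entropy is True             # flag given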
@@ -18,6 +18,7 @@
 import torch

 from megatron import get_args
+from megatron import mpu
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits
@@ -25,9 +26,6 @@ from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal
-from megatron.utils import report_memory
-from megatron import mpu
-

 def gpt2_attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(ltor_mask, -10000.0)
@@ -51,7 +49,7 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

-    def forward(self, input_ids, position_ids, attention_mask, labels,
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
@@ -78,14 +76,12 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]

-        #report_memory('AAA')
-
-        losses = mpu.vocab_parallel_cross_entropy(output, labels)
-
-        #report_memory('BBB')
-        #return output
-        return losses
+        if labels is not None:
+            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            return loss
+        else:
+            return output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
......
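The net effect of the GPT2Model change is that labels becomes optional: with labels, forward returns the unreduced per-token loss (computed on fp16 logits when that is what the model produces); without labels, it returns the logits as before. Below is a toy stand-in illustrating that contract only; ToyLMHead and its shapes are illustrative, and an ordinary unreduced cross entropy stands in for mpu.vocab_parallel_cross_entropy:

# Toy sketch of the labels-optional forward contract (plain PyTorch stand-in).
import torch
import torch.nn.functional as F

class ToyLMHead(torch.nn.Module):
    def __init__(self, hidden, vocab):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, vocab)

    def forward(self, hidden_states, labels=None):
        logits = self.proj(hidden_states)                     # [batch, seq, vocab]
        if labels is not None:
            # Unreduced per-token loss, computed in the logits' own dtype.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   labels.view(-1), reduction='none')
            return loss.view_as(labels)                       # [batch, seq]
        else:
            return logits

model = ToyLMHead(hidden=8, vocab=16)
tokens = torch.randn(2, 4, 8)
labels = torch.randint(0, 16, (2, 4))
print(model(tokens).shape)          # torch.Size([2, 4, 16]) -- logits path
print(model(tokens, labels).shape)  # torch.Size([2, 4])     -- loss path

With --fp16-lm-cross-entropy enabled, forward_step passes labels straight through and the loss is computed on the half-precision logits; otherwise it keeps the old behaviour, upcasting the logits with .contiguous().float() before calling mpu.vocab_parallel_cross_entropy, which materializes an extra fp32 copy of the [batch, seq, vocab] logits.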
@@ -379,7 +379,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag)
-        #report_memory_flag = True

         # Autoresume
         if args.adlr_autoresume and \
......
@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses
-from megatron.utils import report_memory


 def model_provider():
     """Build the model."""
@@ -72,6 +72,7 @@ def get_batch(data_iterator):

 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.
@@ -81,12 +82,13 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    losses = model(tokens, position_ids, attention_mask, labels)
-    #report_memory('CCC')
-    #exit()
-    #losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-    #                                          labels)
-    #report_memory('DDD')
+    if args.fp16_lm_cross_entropy:
+        losses = model(tokens, position_ids, attention_mask, labels=labels)
+    else:
+        output = model(tokens, position_ids, attention_mask)
+        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
+                                                  labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......
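Either branch produces an unreduced losses tensor, which is then averaged over the non-masked positions only. A small numeric sketch of that reduction (dummy values, not taken from the training data):

# Sketch of the masked mean at the end of forward_step: positions with
# loss_mask == 0 contribute to neither the numerator nor the denominator.
import torch

losses = torch.tensor([[2.0, 1.0, 4.0],
                       [3.0, 5.0, 6.0]])      # unreduced per-token losses
loss_mask = torch.tensor([[1.0, 1.0, 0.0],
                          [1.0, 0.0, 0.0]])   # 1 = real token, 0 = masked out

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)                                   # tensor(2.) == (2 + 1 + 3) / 3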