Commit acfe848e authored by mohammad's avatar mohammad
Browse files

added fp16 cross entropy loss option for gpt2

parent 2ede8235
......@@ -294,6 +294,10 @@ def _add_mixed_precision_args(parser):
help='Window over which to raise/lower dynamic scale.')
group.add_argument('--min-scale', type=float, default=1,
help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
return parser
......
......@@ -18,6 +18,7 @@
import torch
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
......@@ -25,9 +26,6 @@ from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
from megatron.utils import report_memory
from megatron import mpu
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0)
......@@ -51,7 +49,7 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
def forward(self, input_ids, position_ids, attention_mask, labels,
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
......@@ -78,14 +76,12 @@ class GPT2Model(MegatronModule):
if get_key_value:
output = [output, presents]
#report_memory('AAA')
losses = mpu.vocab_parallel_cross_entropy(output, labels)
#report_memory('BBB')
if labels is not None:
return output
else:
loss = mpu.vocab_parallel_cross_entropy(output, labels)
return loss
#return output
return losses
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......
......@@ -379,7 +379,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
report_memory_flag)
#report_memory_flag = True
# Autoresume
if args.adlr_autoresume and \
......
......@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses
from megatron.utils import report_memory
def model_provider():
"""Build the model."""
......@@ -72,6 +72,7 @@ def get_batch(data_iterator):
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
......@@ -81,12 +82,13 @@ def forward_step(data_iterator, model):
timers('batch generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels)
#report_memory('CCC')
#exit()
#losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
# labels)
#report_memory('DDD')
if args.fp16_lm_cross_entropy:
losses = model(tokens, position_ids, attention_mask, labels=labels)
else:
output = model(tokens, position_ids, attention_mask)
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment