Commit 2ede8235 authored by mohammad's avatar mohammad
Browse files

testing

parent 5897a790
...@@ -25,6 +25,9 @@ from .language_model import get_language_model ...@@ -25,6 +25,9 @@ from .language_model import get_language_model
from .utils import init_method_normal from .utils import init_method_normal
from .utils import scaled_init_method_normal from .utils import scaled_init_method_normal
from megatron.utils import report_memory
from megatron import mpu
def gpt2_attention_mask_func(attention_scores, ltor_mask): def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0) attention_scores.masked_fill_(ltor_mask, -10000.0)
...@@ -48,7 +51,7 @@ class GPT2Model(MegatronModule): ...@@ -48,7 +51,7 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask, labels,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
...@@ -75,7 +78,14 @@ class GPT2Model(MegatronModule): ...@@ -75,7 +78,14 @@ class GPT2Model(MegatronModule):
if get_key_value: if get_key_value:
output = [output, presents] output = [output, presents]
return output #report_memory('AAA')
losses = mpu.vocab_parallel_cross_entropy(output, labels)
#report_memory('BBB')
#return output
return losses
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
......
...@@ -379,6 +379,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -379,6 +379,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['lr'],
iteration, loss_scale, iteration, loss_scale,
report_memory_flag) report_memory_flag)
#report_memory_flag = True
# Autoresume # Autoresume
if args.adlr_autoresume and \ if args.adlr_autoresume and \
......
...@@ -27,7 +27,7 @@ from megatron.model import GPT2Model ...@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses from megatron.utils import reduce_losses
from megatron.utils import report_memory
def model_provider(): def model_provider():
"""Build the model.""" """Build the model."""
...@@ -81,9 +81,12 @@ def forward_step(data_iterator, model): ...@@ -81,9 +81,12 @@ def forward_step(data_iterator, model):
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
output = model(tokens, position_ids, attention_mask) losses = model(tokens, position_ids, attention_mask, labels)
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), #report_memory('CCC')
labels) #exit()
#losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
# labels)
#report_memory('DDD')
loss_mask = loss_mask.view(-1) loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment