Commit 5f174c07 authored by Mohammad

fp32 fixes

parent 1c1a55da
@@ -102,6 +102,7 @@ class ParallelSelfAttention(MegatronModule):
                  output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
+        self.fp16 = args.fp16
 
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -244,7 +245,7 @@ class ParallelSelfAttention(MegatronModule):
             query_layer, key_layer)
 
         # fp32 conversion.
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_scores = attention_scores.float()
 
         # Apply attention mask. [b, np, s, s]
@@ -267,7 +268,7 @@ class ParallelSelfAttention(MegatronModule):
         attention_probs = self._get_attention_probs(attention_scores)
 
         # fp16 conversion
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_probs = attention_probs.half()
 
         # Context layer. [b, s, hp]
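The two guarded conversions above make the fp32 softmax round-trip conditional on fp16 training: without the new self.fp16 check, a pure fp32 run whose config enables attention_softmax_in_fp32 would have its attention probabilities truncated to half precision by the .half() call. A minimal, self-contained sketch of the pattern follows; the function and tensor names are illustrative, not the repository's code.

import torch
import torch.nn.functional as F

def attention_softmax(attention_scores, fp16, attention_softmax_in_fp32):
    # Softmax over the key dimension, optionally computed in fp32 for
    # numerical stability when the surrounding model runs in fp16.
    if fp16 and attention_softmax_in_fp32:
        # Up-cast: fp16 scores can lose precision inside softmax.
        probs = F.softmax(attention_scores.float(), dim=-1)
        # Down-cast so downstream matmuls stay in half precision.
        return probs.half()
    # fp32 model: no conversion. This is the case the commit fixes --
    # previously the probabilities would have been cast to half anyway.
    return F.softmax(attention_scores, dim=-1)

# Illustrative check: dtypes are preserved in both modes.
fp16_scores = torch.randn(2, 4, 8, 8).half()
fp32_scores = torch.randn(2, 4, 8, 8)
assert attention_softmax(fp16_scores, fp16=True, attention_softmax_in_fp32=True).dtype == torch.float16
assert attention_softmax(fp32_scores, fp16=False, attention_softmax_in_fp32=True).dtype == torch.float32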
@@ -379,9 +379,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         iteration += 1
 
         # Logging.
+        loss_scale = None
+        if args.fp16:
+            loss_scale = optimizer.loss_scale
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
-                                          iteration, optimizer.loss_scale,
+                                          iteration, loss_scale,
                                           report_memory_flag)
 
         # Autoresume
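The training-loop change applies the same guard to logging: only the mixed-precision optimizer wrapper exposes optimizer.loss_scale, so reading it unconditionally breaks fp32 runs. A small illustrative helper (again not the repository's code) that captures the logic:

import torch

def loss_scale_for_logging(optimizer, fp16):
    # Mixed-precision optimizer wrappers carry a loss scale; a plain fp32
    # optimizer does not, so report None instead of raising AttributeError.
    return optimizer.loss_scale if fp16 else None

# Illustrative fp32 case: a vanilla Adam has no loss_scale attribute.
adam = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))])
assert loss_scale_for_logging(adam, fp16=False) is None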