Commit 5f174c07 authored by Mohammad

fp32 fixes

parent 1c1a55da
@@ -102,6 +102,7 @@ class ParallelSelfAttention(MegatronModule):
                  output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
+        self.fp16 = args.fp16
 
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -244,7 +245,7 @@ class ParallelSelfAttention(MegatronModule):
             query_layer, key_layer)
 
         # fp32 conversion.
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_scores = attention_scores.float()
 
         # Apply attention mask. [b, np, s, s]
@@ -267,7 +268,7 @@ class ParallelSelfAttention(MegatronModule):
         attention_probs = self._get_attention_probs(attention_scores)
 
         # fp16 conversion
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_probs = attention_probs.half()
 
         # Context layer. [b, s, hp]
...
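The first hunk caches args.fp16 on the module, and the two later hunks use it to gate the fp32 round-trip: scores are cast to fp32 before the softmax and the probabilities cast back to fp16 only when the model actually runs in half precision, so a pure fp32 run no longer performs a redundant conversion. A minimal standalone sketch of the same pattern, with an illustrative helper name that is not Megatron's API:

    # Standalone sketch of the guarded fp32 softmax; names are illustrative.
    import torch
    import torch.nn.functional as F

    def softmax_maybe_fp32(attention_scores, fp16, attention_softmax_in_fp32=True):
        # Compute the softmax in fp32 only when the inputs are fp16, then
        # cast the probabilities back to fp16 for the rest of the network.
        if fp16 and attention_softmax_in_fp32:
            return F.softmax(attention_scores.float(), dim=-1).half()
        return F.softmax(attention_scores, dim=-1)

    # fp16 scores go through an fp32 softmax and come back as fp16.
    scores = torch.randn(2, 4, 8, 8).half()
    assert softmax_maybe_fp32(scores, fp16=True).dtype == torch.float16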
@@ -379,9 +379,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         iteration += 1
 
         # Logging.
+        loss_scale = None
+        if args.fp16:
+            loss_scale = optimizer.loss_scale
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
-                                          iteration, optimizer.loss_scale,
+                                          iteration, loss_scale,
                                           report_memory_flag)
 
         # Autoresume
...
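The training-loop hunk applies the same guard to logging: optimizer.loss_scale is read only when args.fp16 is set, and an fp32 run passes None to training_log rather than reading a loss-scale attribute that a plain optimizer would not provide. A small sketch of the guard, using a toy stand-in for the loss-scaled optimizer (the class and helper names are hypothetical, not Megatron's):

    # Standalone sketch of the loss-scale logging guard; names are hypothetical.
    import torch

    def loss_scale_for_logging(optimizer, fp16):
        # Only a loss-scaled fp16 optimizer exposes a loss scale; report None
        # for plain fp32 training.
        return optimizer.loss_scale if fp16 else None

    class ToyFP16Optimizer(torch.optim.SGD):
        # Hypothetical stand-in for an fp16 optimizer wrapper with a loss scale.
        loss_scale = 2.0 ** 16

    params = [torch.nn.Parameter(torch.zeros(1))]
    print(loss_scale_for_logging(ToyFP16Optimizer(params, lr=0.1), fp16=True))   # 65536.0
    print(loss_scale_for_logging(torch.optim.SGD(params, lr=0.1), fp16=False))   # None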