OpenDAS / Megatron-LM

Commit acfe848e, authored Jun 05, 2020 by mohammad
parent 2ede8235

added fp16 cross entropy loss option for gpt2
Showing 4 changed files with 20 additions and 19 deletions.
megatron/arguments.py         +4  -0
megatron/model/gpt2_model.py  +7  -11
megatron/training.py          +0  -1
pretrain_gpt2.py              +9  -7
megatron/arguments.py

@@ -294,6 +294,10 @@ def _add_mixed_precision_args(parser):
                        help='Window over which to raise/lower dynamic scale.')
     group.add_argument('--min-scale', type=float, default=1,
                        help='Minimum loss scale for dynamic loss scale.')
+    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the cross entropy unreduced loss calculation '
+                            'for lm head to fp16.')
+
     return parser
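For reference, a minimal standalone sketch of how the new flag parses, using only argparse rather than the full Megatron argument setup: store_true means the fp32 loss path stays the default unless the flag is passed explicitly.

import argparse

# Mirrors the add_argument call above; everything else here is a
# stripped-down stand-in for Megatron's parser setup.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('mixed precision')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation '
                        'for lm head to fp16.')

args = parser.parse_args(['--fp16-lm-cross-entropy'])
print(args.fp16_lm_cross_entropy)  # True; defaults to False when omitted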
megatron/model/gpt2_model.py

@@ -18,6 +18,7 @@
 import torch

 from megatron import get_args
+from megatron import mpu
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits
@@ -25,9 +26,6 @@ from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal
-from megatron.utils import report_memory
-from megatron import mpu
-


 def gpt2_attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(ltor_mask, -10000.0)
@@ -51,7 +49,7 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

-    def forward(self, input_ids, position_ids, attention_mask, labels,
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
@@ -78,14 +76,12 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]

-        #report_memory('AAA')
-        losses = mpu.vocab_parallel_cross_entropy(output, labels)
-        #report_memory('BBB')
-
-        #return output
-        return losses
-
+        if labels is None:
+            return output
+        else:
+            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            return loss

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
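The point of the new labels branch is that the unreduced loss can now be computed directly on the fp16 logits instead of an fp32 copy. A toy sketch of that trade-off follows; unreduced_cross_entropy is a hypothetical stand-in written with basic ops, not Megatron's mpu.vocab_parallel_cross_entropy, which additionally handles the vocabulary-parallel reduction.

import torch

def unreduced_cross_entropy(logits, labels):
    # Numerically stable log-softmax + gather, written so it runs in
    # whatever dtype the logits arrive in (fp16 or fp32).
    z = logits - logits.max(dim=-1, keepdim=True).values
    log_probs = z - z.exp().sum(dim=-1, keepdim=True).log()
    return -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(4, 1024).half()       # [tokens, vocab], small demo vocab
labels = torch.randint(0, 1024, (4,))

loss_fp16 = unreduced_cross_entropy(logits, labels)          # new fp16 path
loss_fp32 = unreduced_cross_entropy(logits.float(), labels)  # old fp32 path

# The fp16 path avoids materializing an fp32 copy of the [tokens, vocab]
# logits; the two results agree to within fp16 rounding error.
print((loss_fp16.float() - loss_fp32).abs().max())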
megatron/training.py

@@ -379,7 +379,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag)
-        #report_memory_flag = True

         # Autoresume
         if args.adlr_autoresume and \
pretrain_gpt2.py

@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses
-from megatron.utils import report_memory


 def model_provider():
     """Build the model."""
@@ -72,6 +72,7 @@ def get_batch(data_iterator):
 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.
@@ -81,12 +82,13 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    losses = model(tokens, position_ids, attention_mask, labels)
-    #report_memory('CCC')
-    #exit()
-    #losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-    #                                          labels)
-    #report_memory('DDD')
+    if args.fp16_lm_cross_entropy:
+        losses = model(tokens, position_ids, attention_mask, labels=labels)
+    else:
+        output = model(tokens, position_ids, attention_mask)
+        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
+                                                  labels)

     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
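Either branch produces unreduced per-token losses, which forward_step then collapses with the same masked average: padded or ignored positions get mask 0 and drop out of the mean. A tiny self-contained illustration of that reduction, with made-up numbers:

import torch

losses = torch.tensor([[2.0, 1.0], [4.0, 8.0]])     # per-token losses [b, s]
loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])  # 0 masks the last token

# Same reduction as the last two lines of the diff above.
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)  # tensor(2.3333): mean over the 3 unmasked tokens only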