OpenDAS / Megatron-LM / Commits

Commit 2ede8235, authored Jun 05, 2020 by mohammad
Parent: 5897a790

    testing

Showing 3 changed files, with 20 additions and 6 deletions:

    megatron/model/gpt2_model.py   +12 -2
    megatron/training.py            +1 -0
    pretrain_gpt2.py                +7 -4
megatron/model/gpt2_model.py

...
@@ -25,6 +25,9 @@ from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal
+from megatron.utils import report_memory
+from megatron import mpu


 def gpt2_attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(ltor_mask, -10000.0)
...
@@ -48,7 +51,7 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

     def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
+                labels, tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
...
@@ -75,7 +78,14 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]
-        return output
+        #report_memory('AAA')
+        losses = mpu.vocab_parallel_cross_entropy(output, labels)
+        #report_memory('BBB')
+        #return output
+        return losses

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
...
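
Taken together, the gpt2_model.py changes make GPT2Model.forward accept the labels tensor and return per-token losses computed with Megatron's model-parallel cross entropy (mpu.vocab_parallel_cross_entropy), instead of returning logits for the caller to post-process. Below is a minimal single-GPU sketch of the equivalent computation, using plain PyTorch cross entropy as a stand-in for the vocab-parallel version; the function name and shapes are illustrative assumptions, not Megatron code.

import torch
import torch.nn.functional as F

def per_token_lm_loss(logits, labels):
    """Single-GPU stand-in for mpu.vocab_parallel_cross_entropy.

    logits: [batch, seq_len, vocab_size] float tensor (model output)
    labels: [batch, seq_len] int64 token ids
    Returns per-token losses of shape [batch, seq_len], like `losses` in the diff.
    """
    vocab_size = logits.size(-1)
    losses = F.cross_entropy(
        logits.float().view(-1, vocab_size),  # flatten batch and sequence dims
        labels.view(-1),
        reduction='none')                     # keep per-token losses, no averaging
    return losses.view_as(labels)

The vocab-parallel implementation computes the same per-token quantity with the vocabulary dimension sharded across tensor-parallel ranks, so no single rank has to materialize the full softmax.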
megatron/training.py

...
@@ -379,6 +379,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag)
+        #report_memory_flag = True

         # Autoresume
         if args.adlr_autoresume and \
...
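
The only change in training.py is a commented-out line that would reset report_memory_flag to True after each logging call, keeping the per-iteration memory report on while testing. report_memory itself (imported into gpt2_model.py above) prints CUDA allocator statistics at a named call site. A rough stand-in with the same shape is sketched below; the exact fields Megatron prints may differ, and a CUDA device is required.

import torch

def report_memory(name):
    # Print current CUDA allocator statistics in MiB, tagged with a
    # call-site name such as 'AAA' or 'BBB' from gpt2_model.py.
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {:.1f}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {:.1f}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {:.1f}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    print(string, flush=True)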
pretrain_gpt2.py

...
@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses
+from megatron.utils import report_memory


 def model_provider():
     """Build the model."""
...
@@ -81,9 +81,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    output = model(tokens, position_ids, attention_mask)
-    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                              labels)
+    losses = model(tokens, position_ids, attention_mask, labels)
+    #report_memory('CCC')
+    #exit()
+    #losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
+    #                                          labels)
+    #report_memory('DDD')
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
...
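
With the loss now computed inside the model, forward_step only has to apply the loss mask: the masked mean averages the per-token losses over positions whose mask is 1, excluding padding/EOD positions. A small self-contained example of that exact reduction follows; the tensor values are illustrative.

import torch

# Per-token losses from the model and a 0/1 mask marking tokens that count.
losses = torch.tensor([[2.0, 1.0, 0.5, 3.0]])     # [batch=1, seq_len=4]
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position masked out

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss.item())  # (2.0 + 1.0 + 0.5) / 3 ≈ 1.1667; the masked token is excluded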