Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
197c132e
Commit
197c132e
authored
Jun 05, 2020
by
mohammad
Browse files
addressed jareds comments
parent
78022005
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
18 deletions
+23
-18
megatron/arguments.py
megatron/arguments.py
+3
-0
megatron/model/bert_model.py
megatron/model/bert_model.py
+7
-1
megatron/model/gpt2_model.py
megatron/model/gpt2_model.py
+6
-1
pretrain_bert.py
pretrain_bert.py
+5
-9
pretrain_gpt2.py
pretrain_gpt2.py
+2
-7
No files found.
megatron/arguments.py
View file @
197c132e
...
...
@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
if
args
.
num_unique_layers
<
args
.
num_layers
:
assert
args
.
DDP_impl
==
'local'
,
\
'torch-DDP does not work with parameters sharing.'
# Mixed precision checks.
if
args
.
fp16_lm_cross_entropy
:
assert
args
.
fp16
,
'lm cross entropy in fp16 only support in fp16 mode.'
_print_args
(
args
)
return
args
...
...
megatron/model/bert_model.py
View file @
197c132e
...
...
@@ -115,6 +115,7 @@ class BertModel(MegatronModule):
super
(
BertModel
,
self
).
__init__
()
args
=
get_args
()
self
.
fp16_lm_cross_entropy
=
args
.
fp16_lm_cross_entropy
self
.
add_binary_head
=
add_binary_head
self
.
parallel_output
=
parallel_output
init_method
=
init_method_normal
(
args
.
init_method_std
)
...
...
@@ -170,7 +171,12 @@ class BertModel(MegatronModule):
if
lm_labels
is
None
:
return
lm_logits
,
binary_logits
else
:
lm_loss
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
,
lm_labels
)
if
self
.
fp16_lm_cross_entropy
:
assert
lm_logits
.
dtype
==
torch
.
half
lm_loss
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
,
lm_labels
)
else
:
lm_loss
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
.
float
(),
lm_labels
)
return
lm_loss
,
binary_logits
...
...
megatron/model/gpt2_model.py
View file @
197c132e
...
...
@@ -40,6 +40,7 @@ class GPT2Model(MegatronModule):
args
=
get_args
()
self
.
parallel_output
=
parallel_output
self
.
fp16_lm_cross_entropy
=
args
.
fp16_lm_cross_entropy
self
.
language_model
,
self
.
_language_model_key
=
get_language_model
(
attention_mask_func
=
gpt2_attention_mask_func
,
...
...
@@ -79,7 +80,11 @@ class GPT2Model(MegatronModule):
if
labels
is
None
:
return
output
else
:
loss
=
mpu
.
vocab_parallel_cross_entropy
(
output
,
labels
)
if
self
.
fp16_lm_cross_entropy
:
assert
output
.
dtype
==
torch
.
half
loss
=
mpu
.
vocab_parallel_cross_entropy
(
output
,
labels
)
else
:
loss
=
mpu
.
vocab_parallel_cross_entropy
(
output
.
float
(),
labels
)
return
loss
...
...
pretrain_bert.py
View file @
197c132e
...
...
@@ -78,16 +78,12 @@ def forward_step(data_iterator, model):
timers
(
'batch generator'
).
stop
()
# Forward model. lm_labels
if
args
.
fp16_lm_cross_entropy
:
lm_loss_
,
sop_logits
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
,
lm_labels
=
lm_labels
)
else
:
lm_logits
,
sop_logits
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
)
lm_loss_
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
.
contiguous
().
float
(),
lm_labels
.
contiguous
())
lm_loss_
,
sop_logits
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
,
lm_labels
=
lm_labels
)
sop_loss
=
F
.
cross_entropy
(
sop_logits
.
view
(
-
1
,
2
).
contiguous
().
float
(),
sentence_order
.
view
(
-
1
)
.
contiguous
()
,
sop_loss
=
F
.
cross_entropy
(
sop_logits
.
view
(
-
1
,
2
).
float
(),
sentence_order
.
view
(
-
1
),
ignore_index
=-
1
)
lm_loss
=
torch
.
sum
(
...
...
pretrain_gpt2.py
View file @
197c132e
...
...
@@ -82,13 +82,8 @@ def forward_step(data_iterator, model):
timers
(
'batch generator'
).
stop
()
# Forward model.
if
args
.
fp16_lm_cross_entropy
:
losses
=
model
(
tokens
,
position_ids
,
attention_mask
,
labels
=
labels
)
else
:
output
=
model
(
tokens
,
position_ids
,
attention_mask
)
losses
=
mpu
.
vocab_parallel_cross_entropy
(
output
.
contiguous
().
float
(),
labels
)
losses
=
model
(
tokens
,
position_ids
,
attention_mask
,
labels
=
labels
)
loss_mask
=
loss_mask
.
view
(
-
1
)
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment