OpenDAS / Megatron-LM

Commit 78022005
authored Jun 05, 2020 by mohammad

    added fp16 lm cross entropy to bert

parent 22e3c7e6
Showing 2 changed files with 18 additions and 6 deletions (+18 -6):

    megatron/model/bert_model.py    +9 -2
    pretrain_bert.py                +9 -4
megatron/model/bert_model.py  (view file @ 78022005)

...
@@ -18,6 +18,7 @@
 import torch
 
 from megatron import get_args
+from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
...
@@ -138,7 +139,8 @@ class BertModel(MegatronModule):
             init_method)
         self._binary_head_key = 'binary_head'
 
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None,
+                lm_labels=None):
 
         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)
...
@@ -161,11 +163,16 @@ class BertModel(MegatronModule):
         lm_logits = self.lm_head(
             lm_output, self.language_model.embedding.word_embeddings.weight)
 
+        binary_logits = None
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
-            return lm_logits, binary_logits
 
-        return lm_logits, None
+        if lm_labels is None:
+            return lm_logits, binary_logits
+        else:
+            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            return lm_loss, binary_logits
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
...
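For context, the change above moves the masked-LM cross entropy inside BertModel.forward: when lm_labels is passed, the per-token loss is computed directly on lm_logits (which stay in half precision under mixed-precision training) via mpu.vocab_parallel_cross_entropy, and (lm_loss, binary_logits) is returned instead of (lm_logits, binary_logits). Below is a minimal, self-contained sketch of that return contract in plain PyTorch; F.cross_entropy stands in for Megatron's vocab-parallel implementation, and the shapes are illustrative assumptions, not the real model's.

    # Sketch only: plain-PyTorch stand-in for the new BertModel.forward contract.
    import torch
    import torch.nn.functional as F

    def forward_sketch(lm_logits, binary_logits, lm_labels=None):
        # lm_labels is None -> return the logits, as before this commit.
        if lm_labels is None:
            return lm_logits, binary_logits
        # lm_labels given -> return the per-token LM loss computed on the logits,
        # with no .float() upcast in the caller; F.cross_entropy(reduction='none')
        # stands in for mpu.vocab_parallel_cross_entropy, which is likewise unreduced.
        lm_loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)),
                                  lm_labels.view(-1), reduction='none')
        return lm_loss.view_as(lm_labels), binary_logits

    # Toy usage with made-up shapes: batch=2, seq=4, vocab=16.
    logits = torch.randn(2, 4, 16)
    labels = torch.randint(0, 16, (2, 4))
    per_token_loss, _ = forward_sketch(logits, None, lm_labels=labels)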
pretrain_bert.py  (view file @ 78022005)

...
@@ -68,6 +68,7 @@ def get_batch(data_iterator):
 
 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()
 
     # Get the batch.
...
@@ -76,15 +77,19 @@ def forward_step(data_iterator, model):
         = get_batch(data_iterator)
     timers('batch generator').stop()
 
-    # Forward model.
-    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+    # Forward model. lm_labels
+    if args.fp16_lm_cross_entropy:
+        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
+                                     lm_labels=lm_labels)
+    else:
+        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+        lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
+                                                    lm_labels.contiguous())
 
     sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                                sentence_order.view(-1).contiguous(),
                                ignore_index=-1)
 
-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
...
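To make the two branches concrete: when args.fp16_lm_cross_entropy is set, the model itself returns the unreduced LM loss computed on the half-precision logits, so the training script no longer upcasts the full [batch, seq, vocab] logits tensor with .float() before the loss; otherwise it keeps the old path of computing mpu.vocab_parallel_cross_entropy on fp32 logits. Either way, the per-token losses are then averaged over masked positions only. The following is a small, self-contained sketch of that masked averaging in plain PyTorch (F.cross_entropy stands in for the vocab-parallel loss and the shapes are made up).

    # Sketch of the masked-LM loss averaging done at the end of forward_step.
    import torch
    import torch.nn.functional as F

    batch, seq, vocab = 2, 8, 32
    lm_logits = torch.randn(batch, seq, vocab)         # stand-in for model output
    lm_labels = torch.randint(0, vocab, (batch, seq))
    loss_mask = torch.zeros(batch, seq)
    loss_mask[:, :2] = 1.0                             # pretend the first two tokens were masked

    # Per-token cross entropy (unreduced), mirroring the non-fp16 branch, where the
    # logits are upcast to fp32 before the loss.
    lm_loss_ = F.cross_entropy(lm_logits.view(-1, vocab).float(),
                               lm_labels.view(-1), reduction='none')

    # Same masked average as the diff: only masked positions contribute to lm_loss.
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()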