Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
ac967fa0
Commit
ac967fa0
authored
Jun 26, 2020
by
Neel Kant
Browse files
Merge branch 'master' into ict-merge
parents
7b3baaaa
46a536cc
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
49 additions
and
17 deletions
+49
-17
megatron/arguments.py
megatron/arguments.py
+7
-0
megatron/data/bert_dataset.py
megatron/data/bert_dataset.py
+1
-1
megatron/data/gpt2_dataset.py
megatron/data/gpt2_dataset.py
+3
-3
megatron/model/bert_model.py
megatron/model/bert_model.py
+15
-2
megatron/model/gpt2_model.py
megatron/model/gpt2_model.py
+13
-2
pretrain_bert.py
pretrain_bert.py
+7
-6
pretrain_gpt2.py
pretrain_gpt2.py
+3
-3
No files found.
megatron/arguments.py
View file @
ac967fa0
...
...
@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
if
args
.
num_unique_layers
<
args
.
num_layers
:
assert
args
.
DDP_impl
==
'local'
,
\
'torch-DDP does not work with parameters sharing.'
# Mixed precision checks.
if
args
.
fp16_lm_cross_entropy
:
assert
args
.
fp16
,
'lm cross entropy in fp16 only support in fp16 mode.'
_print_args
(
args
)
return
args
...
...
@@ -300,6 +303,10 @@ def _add_mixed_precision_args(parser):
help
=
'Window over which to raise/lower dynamic scale.'
)
group
.
add_argument
(
'--min-scale'
,
type
=
float
,
default
=
1
,
help
=
'Minimum loss scale for dynamic loss scale.'
)
group
.
add_argument
(
'--fp16-lm-cross-entropy'
,
action
=
'store_true'
,
help
=
'Move the cross entropy unreduced loss calculation'
'for lm head to fp16.'
)
return
parser
...
...
megatron/data/bert_dataset.py
View file @
ac967fa0
...
...
@@ -159,7 +159,7 @@ def get_samples_mapping_(indexed_dataset,
print_rank_0
(
' > loading indexed mapping from {}'
.
format
(
indexmap_filename
))
start_time
=
time
.
time
()
samples_mapping
=
np
.
load
(
indexmap_filename
,
allow_pickle
=
True
)
samples_mapping
=
np
.
load
(
indexmap_filename
,
allow_pickle
=
True
,
mmap_mode
=
'r'
)
print_rank_0
(
' loaded indexed file in {:3.3f} seconds'
.
format
(
time
.
time
()
-
start_time
))
print_rank_0
(
' total number of samples: {}'
.
format
(
...
...
megatron/data/gpt2_dataset.py
View file @
ac967fa0
...
...
@@ -211,13 +211,13 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time
=
time
.
time
()
print_rank_0
(
' > loading doc-idx mapping from {}'
.
format
(
doc_idx_filename
))
doc_idx
=
np
.
load
(
doc_idx_filename
,
allow_pickle
=
True
)
doc_idx
=
np
.
load
(
doc_idx_filename
,
allow_pickle
=
True
,
mmap_mode
=
'r'
)
print_rank_0
(
' > loading sample-idx mapping from {}'
.
format
(
sample_idx_filename
))
sample_idx
=
np
.
load
(
sample_idx_filename
,
allow_pickle
=
True
)
sample_idx
=
np
.
load
(
sample_idx_filename
,
allow_pickle
=
True
,
mmap_mode
=
'r'
)
print_rank_0
(
' > loading shuffle-idx mapping from {}'
.
format
(
shuffle_idx_filename
))
shuffle_idx
=
np
.
load
(
shuffle_idx_filename
,
allow_pickle
=
True
)
shuffle_idx
=
np
.
load
(
shuffle_idx_filename
,
allow_pickle
=
True
,
mmap_mode
=
'r'
)
print_rank_0
(
' loaded indexed file in {:3.3f} seconds'
.
format
(
time
.
time
()
-
start_time
))
print_rank_0
(
' total number of samples: {}'
.
format
(
...
...
megatron/model/bert_model.py
View file @
ac967fa0
...
...
@@ -18,6 +18,7 @@
import
torch
from
megatron
import
get_args
from
megatron
import
mpu
from
megatron.model.language_model
import
parallel_lm_logits
from
megatron.model.language_model
import
get_language_model
from
megatron.model.transformer
import
LayerNorm
...
...
@@ -80,6 +81,7 @@ class BertModel(MegatronModule):
super
(
BertModel
,
self
).
__init__
()
args
=
get_args
()
self
.
fp16_lm_cross_entropy
=
args
.
fp16_lm_cross_entropy
self
.
add_binary_head
=
add_binary_head
self
.
parallel_output
=
parallel_output
init_method
=
init_method_normal
(
args
.
init_method_std
)
...
...
@@ -102,7 +104,8 @@ class BertModel(MegatronModule):
init_method
)
self
.
_binary_head_key
=
'binary_head'
def
forward
(
self
,
input_ids
,
attention_mask
,
tokentype_ids
=
None
):
def
forward
(
self
,
input_ids
,
attention_mask
,
tokentype_ids
=
None
,
lm_labels
=
None
):
extended_attention_mask
=
bert_extended_attention_mask
(
attention_mask
,
next
(
self
.
language_model
.
parameters
()).
dtype
)
...
...
@@ -125,11 +128,21 @@ class BertModel(MegatronModule):
lm_logits
=
self
.
lm_head
(
lm_output
,
self
.
language_model
.
embedding
.
word_embeddings
.
weight
)
binary_logits
=
None
if
self
.
add_binary_head
:
binary_logits
=
self
.
binary_head
(
pooled_output
)
if
lm_labels
is
None
:
return
lm_logits
,
binary_logits
else
:
if
self
.
fp16_lm_cross_entropy
:
assert
lm_logits
.
dtype
==
torch
.
half
lm_loss
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
,
lm_labels
)
else
:
lm_loss
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
.
float
(),
lm_labels
)
return
lm_loss
,
binary_logits
return
lm_logits
,
None
def
state_dict_for_save_checkpoint
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
):
...
...
megatron/model/gpt2_model.py
View file @
ac967fa0
...
...
@@ -18,6 +18,7 @@
import
torch
from
megatron
import
get_args
from
megatron
import
mpu
from
megatron.module
import
MegatronModule
from
.language_model
import
parallel_lm_logits
...
...
@@ -39,6 +40,7 @@ class GPT2Model(MegatronModule):
args
=
get_args
()
self
.
parallel_output
=
parallel_output
self
.
fp16_lm_cross_entropy
=
args
.
fp16_lm_cross_entropy
self
.
language_model
,
self
.
_language_model_key
=
get_language_model
(
attention_mask_func
=
gpt2_attention_mask_func
,
...
...
@@ -48,7 +50,7 @@ class GPT2Model(MegatronModule):
scaled_init_method
=
scaled_init_method_normal
(
args
.
init_method_std
,
args
.
num_layers
))
def
forward
(
self
,
input_ids
,
position_ids
,
attention_mask
,
def
forward
(
self
,
input_ids
,
position_ids
,
attention_mask
,
labels
=
None
,
tokentype_ids
=
None
,
layer_past
=
None
,
get_key_value
=
False
,
forward_method_parallel_output
=
None
):
...
...
@@ -75,7 +77,16 @@ class GPT2Model(MegatronModule):
if
get_key_value
:
output
=
[
output
,
presents
]
return
output
if
labels
is
None
:
return
output
else
:
if
self
.
fp16_lm_cross_entropy
:
assert
output
.
dtype
==
torch
.
half
loss
=
mpu
.
vocab_parallel_cross_entropy
(
output
,
labels
)
else
:
loss
=
mpu
.
vocab_parallel_cross_entropy
(
output
.
float
(),
labels
)
return
loss
def
state_dict_for_save_checkpoint
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
):
...
...
pretrain_bert.py
View file @
ac967fa0
...
...
@@ -67,6 +67,7 @@ def get_batch(data_iterator):
def
forward_step
(
data_iterator
,
model
):
"""Forward step."""
args
=
get_args
()
timers
=
get_timers
()
# Get the batch.
...
...
@@ -75,15 +76,15 @@ def forward_step(data_iterator, model):
=
get_batch
(
data_iterator
)
timers
(
'batch generator'
).
stop
()
# Forward model.
lm_logits
,
sop_logits
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
)
# Forward model. lm_labels
lm_loss_
,
sop_logits
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
,
lm_labels
=
lm_labels
)
sop_loss
=
F
.
cross_entropy
(
sop_logits
.
view
(
-
1
,
2
).
contiguous
().
float
(),
sentence_order
.
view
(
-
1
)
.
contiguous
()
,
sop_loss
=
F
.
cross_entropy
(
sop_logits
.
view
(
-
1
,
2
).
float
(),
sentence_order
.
view
(
-
1
),
ignore_index
=-
1
)
lm_loss_
=
mpu
.
vocab_parallel_cross_entropy
(
lm_logits
.
contiguous
().
float
(),
lm_labels
.
contiguous
())
lm_loss
=
torch
.
sum
(
lm_loss_
.
view
(
-
1
)
*
loss_mask
.
reshape
(
-
1
))
/
loss_mask
.
sum
()
...
...
pretrain_gpt2.py
View file @
ac967fa0
...
...
@@ -71,6 +71,7 @@ def get_batch(data_iterator):
def
forward_step
(
data_iterator
,
model
):
"""Forward step."""
args
=
get_args
()
timers
=
get_timers
()
# Get the batch.
...
...
@@ -80,9 +81,8 @@ def forward_step(data_iterator, model):
timers
(
'batch generator'
).
stop
()
# Forward model.
output
=
model
(
tokens
,
position_ids
,
attention_mask
)
losses
=
mpu
.
vocab_parallel_cross_entropy
(
output
.
contiguous
().
float
(),
labels
)
losses
=
model
(
tokens
,
position_ids
,
attention_mask
,
labels
=
labels
)
loss_mask
=
loss_mask
.
view
(
-
1
)
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment