OpenDAS / Megatron-LM · Commits

Commit 3366a5b0
Authored Mar 28, 2020 by Mohammad
refactored pretrain-bert
Parent: 27e14f82
Showing 1 changed file with 14 additions and 27 deletions
pretrain_bert.py  (+14, −27)
@@ -20,15 +20,15 @@ import torch.nn.functional as F

 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
-from megatron.model import BertModel
 from megatron import print_rank_0
-from megatron.utils import reduce_losses
-from megatron.utils import vocab_size_with_padding
-from megatron.training import pretrain
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.data_utils.samplers import DistributedBatchSampler
+from megatron.model import BertModel
+from megatron.training import pretrain
+from megatron.utils import reduce_losses


 def model_provider():
@@ -114,7 +114,7 @@ def forward_step(data_iterator, model):

 def get_train_val_test_data():
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""

     args = get_args()
     (train_data, valid_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
@@ -176,36 +176,23 @@ def get_train_val_test_data():

         do_valid = valid_data is not None and args.eval_iters > 0
         do_test = test_data is not None and args.eval_iters > 0
         # Need to broadcast num_tokens and num_type_tokens.
-        num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
-        token_counts = torch.cuda.LongTensor([num_tokens,
-                                              2, # hard coded num_type_tokens
-                                              int(do_train), int(do_valid), int(do_test)])
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
     else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
+        flags = torch.cuda.LongTensor([0, 0, 0])

     # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
+    torch.distributed.broadcast(flags,
                                 mpu.get_model_parallel_src_rank(),
                                 group=mpu.get_model_parallel_group())
-    args.vocab_size = token_counts[0].item()
-    args.tokentype_size = token_counts[1].item()
-    args.do_train = token_counts[2].item()
-    args.do_valid = token_counts[3].item()
-    args.do_test = token_counts[4].item()
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()

     return train_data, valid_data, test_data

 if __name__ == "__main__":

-    '''
-    from megatron.initialize import initialize_megatron
-    initialize_megatron(args_defaults={
-        'tokenizer_type': 'BertWordPieceLowerCase'})
-    exit()
-    '''
     pretrain(get_train_val_test_data, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
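
The substantive change is in the last two hunks: instead of computing a padded vocabulary size on rank 0 and broadcasting a five-element token_counts tensor, get_train_val_test_data now broadcasts only the three dataset-presence flags (do_train, do_valid, do_test) from the model-parallel source rank; vocabulary handling presumably moves out of this script and into the tokenizer configured through the 'BertWordPieceLowerCase' args_defaults. The sketch below is a minimal, standalone illustration of that broadcast pattern, not Megatron's code: it assumes a single process, the "gloo" backend, and CPU tensors, whereas the commit itself uses torch.cuda.LongTensor with mpu.get_model_parallel_src_rank() and mpu.get_model_parallel_group(); the helper name broadcast_data_flags is invented here for illustration.

# Minimal sketch of the flag-broadcast pattern introduced by this commit.
# Assumptions: single process, "gloo" backend, CPU tensors (the real code
# runs on CUDA inside Megatron's model-parallel process group).
import os
import torch
import torch.distributed as dist

def broadcast_data_flags(do_train, do_valid, do_test, src_rank=0, group=None):
    # Hypothetical helper: rank `src_rank` supplies the flags; after the
    # broadcast every rank in the group holds the same three values.
    flags = torch.tensor([int(do_train), int(do_valid), int(do_test)],
                         dtype=torch.long)
    dist.broadcast(flags, src_rank, group=group)
    return bool(flags[0].item()), bool(flags[1].item()), bool(flags[2].item())

if __name__ == "__main__":
    # One-process group, just to exercise the call path end to end.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(broadcast_data_flags(do_train=True, do_valid=True, do_test=False))
    dist.destroy_process_group()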