OpenDAS / Megatron-LM · Commit 1788c910

Commit 1788c910 authored Mar 29, 2020 by Mohammad
Commit message: both bert and gpt2 tested and working
parent 5f8623db

Showing 4 changed files with 54 additions and 62 deletions.
megatron/arguments.py   +10  -10
megatron/utils.py       +26   -2
pretrain_bert.py         +4  -23
pretrain_gpt2.py        +14  -27
megatron/arguments.py

@@ -312,9 +312,16 @@ def _add_data_args(parser):
                        choices=['BertWordPieceLowerCase',
                                 'GPT2BPETokenizer'],
                        help='What type of tokenizer to use.')
-    parser.add_argument('--data-impl', type=str, default='infer',
-                        choices=['lazy', 'cached', 'mmap', 'infer'],
-                        help='Implementation of indexed datasets.')
+    group.add_argument('--data-impl', type=str, default='infer',
+                       choices=['lazy', 'cached', 'mmap', 'infer'],
+                       help='Implementation of indexed datasets.')
+    group.add_argument('--reset-position-ids', action='store_true',
+                       help='Reset posistion ids after end-of-document token.')
+    group.add_argument('--reset-attention-mask', action='store_true',
+                       help='Reset self attention maske after '
+                       'end-of-document token.')
+    group.add_argument('--eod-mask-loss', action='store_true',
+                       help='Mask loss for the end of document tokens.')

     return parser

@@ -340,13 +347,6 @@ def _add_gpt2_args(parser):
     group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                        help='The filename containing all the shards '
                        'sizes for numpy data loader')
-    group.add_argument('--reset-position-ids', action='store_true',
-                       help='Reset posistion ids after end-of-document token.')
-    group.add_argument('--reset-attention-mask', action='store_true',
-                       help='Reset self attention maske after '
-                       'end-of-document token.')
-    group.add_argument('--eod-mask-loss', action='store_true',
-                       help='Mask loss for the end of document tokens.')

     return parser
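The net effect of this change is that --data-impl and the document-boundary flags (--reset-position-ids, --reset-attention-mask, --eod-mask-loss) are registered in the shared data-arguments group instead of the GPT-2-only group, so both pretrain_bert.py and pretrain_gpt2.py see the same options. Below is a minimal, self-contained argparse sketch of that pattern; the group title and the sample parse_args() call are illustrative assumptions, and only the flag names, defaults, and help strings come from the diff.

import argparse

parser = argparse.ArgumentParser(description='toy Megatron-style argument parsing')
group = parser.add_argument_group(title='data and dataloader')  # group title is assumed

group.add_argument('--data-impl', type=str, default='infer',
                   choices=['lazy', 'cached', 'mmap', 'infer'],
                   help='Implementation of indexed datasets.')
group.add_argument('--reset-position-ids', action='store_true',
                   help='Reset position ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
                   help='Reset self attention mask after end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
                   help='Mask loss for the end of document tokens.')

# Either pretraining script would now accept these flags on its command line.
args = parser.parse_args(['--data-impl', 'mmap', '--reset-position-ids'])
print(args.data_impl, args.reset_position_ids, args.eod_mask_loss)  # mmap True False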
megatron/utils.py

@@ -21,8 +21,10 @@ import torch
 from megatron import get_args
 from megatron import get_adlr_autoresume
+from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
+from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.fp16 import FP16_Optimizer

@@ -87,7 +89,30 @@ def check_adlr_autoresume_termination(iteration, model,
             sys.exit(0)


+###################################################
+def make_data_loader(dataset):
+    """Buld dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    # Data parallel arguments.
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = args.batch_size * world_size
+    num_workers = args.num_workers
+
+    # Use a simple sampler with distributed batch sampler.
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    batch_sampler = DistributedBatchSampler(sampler=sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=True,
+                                            rank=rank,
+                                            world_size=world_size)
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 def get_ltor_masks_and_position_ids(data,

@@ -145,4 +170,3 @@ def get_ltor_masks_and_position_ids(data,
             prev_index = i + 1

     return attention_mask, loss_mask, position_ids
-
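make_data_loader() centralizes the sampler / batch-sampler / DataLoader wiring that the BERT and GPT-2 scripts previously duplicated. The sketch below imitates the core idea with only torch built-ins: a sequential sampler wrapped in a batch sampler that forms one global batch and hands each data-parallel rank its contiguous slice. SlicedBatchSampler is an illustrative stand-in with assumed slicing semantics, not Megatron's DistributedBatchSampler, and it runs single-process on a toy dataset.

import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

class SlicedBatchSampler(BatchSampler):
    """Yield each rank's slice of a global batch (assumed semantics)."""
    def __init__(self, sampler, global_batch_size, rank, world_size, drop_last=True):
        super().__init__(sampler, global_batch_size, drop_last)
        self.rank = rank
        self.per_rank = global_batch_size // world_size
    def __iter__(self):
        for batch in super().__iter__():
            # Keep only this rank's contiguous chunk of the global batch.
            start = self.rank * self.per_rank
            yield batch[start:start + self.per_rank]

dataset = TensorDataset(torch.arange(32))
sampler = SequentialSampler(dataset)
batch_sampler = SlicedBatchSampler(sampler, global_batch_size=8, rank=0, world_size=2)
loader = DataLoader(dataset, batch_sampler=batch_sampler,
                    num_workers=0, pin_memory=True)
for (batch,) in loader:
    print(batch)  # rank 0 gets the first 4 samples of each global batch of 8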
pretrain_bert.py

@@ -23,14 +23,12 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
-from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.model import BertModel
 from megatron.training import pretrain
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""

     args = get_args()

@@ -151,26 +149,9 @@ def get_train_val_test_data():
             skip_warmup=(not args.mmap_warmup))
         print_rank_0("> finished creating BERT datasets ...")

-        def make_data_loader_(dataset):
-            if not dataset:
-                return None
-            # Use a simple sampler with distributed batch sampler.
-            sampler = torch.utils.data.SequentialSampler(dataset)
-            batch_sampler = DistributedBatchSampler(sampler=sampler,
-                                                    batch_size=global_batch_size,
-                                                    drop_last=True,
-                                                    rank=data_parallel_rank,
-                                                    world_size=data_parallel_size)
-            # Torch dataloader.
-            return torch.utils.data.DataLoader(dataset,
-                                               batch_sampler=batch_sampler,
-                                               num_workers=args.num_workers,
-                                               pin_memory=True)
-
-        train_data = make_data_loader_(train_ds)
-        valid_data = make_data_loader_(valid_ds)
-        test_data = make_data_loader_(test_ds)
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)

         do_train = train_data is not None and args.train_iters > 0
         do_valid = valid_data is not None and args.eval_iters > 0
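With the inline closure gone, pretrain_bert.py simply calls the shared helper once per split, and the existing do_train / do_valid guards handle any split that is missing. A tiny runnable sketch of that guard chain, with a stub standing in for make_data_loader and toy iteration counts in place of args.train_iters / args.eval_iters:

def make_loader_or_none(dataset):
    # Stand-in for make_data_loader(): a missing dataset yields no loader.
    return None if dataset is None else list(dataset)

train_data = make_loader_or_none([1, 2, 3])
valid_data = make_loader_or_none(None)   # e.g. no validation split on disk
train_iters, eval_iters = 1000, 100      # toy values

do_train = train_data is not None and train_iters > 0
do_valid = valid_data is not None and eval_iters > 0
print(do_train, do_valid)  # True False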
pretrain_gpt2.py

@@ -25,10 +25,10 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.data.gpt2_dataset import GPT2Dataset
-from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses

@@ -121,32 +121,19 @@ def make_gpt2_dataloaders():
     seq_length = args.seq_length
     initial_seed = args.seed

-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    def make_data_loader_(data_path):
-        # Build the dataset.
-        dataset = GPT2Dataset(data_path, input_data_sizes_file,
-                              seq_length, initial_seed)
-        # Use a simple sampler with distributed batch sampler.
-        sampler = torch.utils.data.SequentialSampler(dataset)
-        batch_sampler = DistributedBatchSampler(sampler=sampler,
-                                                batch_size=global_batch_size,
-                                                drop_last=True,
-                                                rank=rank,
-                                                world_size=world_size)
-        # Torch dataloader.
-        return torch.utils.data.DataLoader(dataset,
-                                           batch_sampler=batch_sampler,
-                                           num_workers=num_workers,
-                                           pin_memory=True)
-
-    train = make_data_loader_(os.path.join(args.data_path, 'train'))
-    valid = make_data_loader_(os.path.join(args.data_path, 'valid'))
-    test = make_data_loader_(os.path.join(args.data_path, 'test'))
+    # Build the datasets.
+    def build_dataset_(name):
+        return GPT2Dataset(os.path.join(args.data_path, name),
+                           args.input_data_sizes_file,
+                           args.seq_length, args.seed)
+    train_ds = build_dataset_('train')
+    valid_ds = build_dataset_('valid')
+    test_ds = build_dataset_('test')
+
+    # Dataloaders
+    train = make_data_loader(train_ds)
+    valid = make_data_loader(valid_ds)
+    test = make_data_loader(test_ds)

     args.do_train = False
     args.do_valid = False
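make_gpt2_dataloaders() now builds the three GPT2Dataset splits through a small factory keyed by split name and hands each one to the same make_data_loader() helper. The toy sketch below mirrors that factory shape with a stand-in dataset class, since constructing a real GPT2Dataset requires preprocessed data files and the sizes file; ToySplitDataset and the /tmp/data path are assumptions for illustration only.

import os
from torch.utils.data import Dataset

class ToySplitDataset(Dataset):
    # Hypothetical stand-in for GPT2Dataset(path, sizes_file, seq_length, seed).
    def __init__(self, path, seq_length, seed):
        self.path, self.seq_length, self.seed = path, seq_length, seed
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {'text': [idx] * self.seq_length}

def build_dataset(data_path, name, seq_length=16, seed=1234):
    # One factory call per split name, as build_dataset_('train'/'valid'/'test') does.
    return ToySplitDataset(os.path.join(data_path, name), seq_length, seed)

splits = {name: build_dataset('/tmp/data', name) for name in ('train', 'valid', 'test')}
print({name: len(ds) for name, ds in splits.items()})  # {'train': 8, 'valid': 8, 'test': 8}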