OpenDAS / Megatron-LM

Commit 2e38461b, authored Apr 12, 2020 by Mohammad
Parent: eb74fa34

data loading for BERT and GPT cleaned up
Showing 5 changed files with 118 additions and 468 deletions.
megatron/data/old_gpt2_dataset.py   +0   -136
megatron/training.py                +85  -42
pretrain_bert.py                    +18  -61
pretrain_gpt2.py                    +15  -61
pretrain_gpt2_old.py                +0   -168
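After this change, a pretraining script no longer builds its own dataloaders or broadcasts the do_train/do_valid/do_test flags; it only supplies a dataset-provider function, and megatron/training.py turns the returned datasets into dataloaders and iterators. Below is a minimal, self-contained sketch of that provider contract; the toy dataset class and the sample counts are illustrative stand-ins, not part of the commit.

# Minimal sketch of the provider contract introduced by this commit.
# ToyDataset is a placeholder, not Megatron code.
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Placeholder dataset that just returns its index as the sample."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {'text': idx}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets of at least the requested sizes."""
    train_num, valid_num, test_num = train_val_test_num_samples
    return ToyDataset(train_num), ToyDataset(valid_num), ToyDataset(test_num)


# training.py computes the sample counts and calls the provider; a script only
# passes the provider to pretrain(), as pretrain_bert.py / pretrain_gpt2.py do:
#     pretrain(train_valid_test_datasets_provider, model_provider, forward_step)
print(train_valid_test_datasets_provider([8, 2, 2]))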
megatron/data/old_gpt2_dataset.py (deleted, 100644 → 0)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 dataset."""
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class GPT2Dataset(Dataset):

    def __init__(self, data_path, sizes_filename, seq_length, initial_seed,
                 max_epochs=100):
        # Input parameters.
        self.data_path = data_path
        self.sizes_filename = sizes_filename
        self.seq_length = seq_length
        self.initial_seed = initial_seed
        self.max_epochs = max_epochs
        # Shard stuff.
        # Dictionary from shard name to its size (number of elements).
        self.master_shard_size_dict = None
        # Dictionary from shard name to modified size so it is
        # divisible by self.seq_length.
        self.shard_size_dict = None
        # Long array (self.max_epochs * num-shards) populated
        # randomly with shard names.
        self.shards_name = None
        # Start index of the data for a shard.
        self.shards_start_index = None
        self.build_shard_mappings_()
        self.data_length = self.shards_start_index[-1]
        # Data.
        self.shards_data = [None] * self.shards_name.size
        self.shards_sample_index = [None] * self.shards_name.size

    def __len__(self):
        return self.data_length

    def __getitem__(self, idx):
        # Find which shard we need.
        shard_index = np.searchsorted(self.shards_start_index, idx,
                                      side='right') - 1
        # data index in the shard.
        data_idx = idx - self.shards_start_index[shard_index]
        # Load the shard if it is not in memory.
        if self.shards_data[shard_index] is None:
            print('global rank {} is building data for shard index {} ...'.
                  format(torch.distributed.get_rank(), shard_index))
            self.build_dataset_(shard_index)
            #assert self.shards_data[shard_index] is not None
        # Start index.
        start_index = self.shards_sample_index[shard_index][data_idx]
        # Add one for label shift.
        end_index = start_index + self.seq_length + 1
        data = self.shards_data[shard_index][start_index:end_index]
        return {'text': np.array(data, dtype=np.int64)}

    def build_dataset_(self, shard_index):
        # Garbage collect so we don't use a lot of memory.
        # Leave the last one in case other threads have not caught up yet.
        #for i in range(shard_index - 1):
        for i in range(shard_index):
            self.shards_data[i] = None
            self.shards_sample_index[i] = None
        # Read the shard.
        filename = os.path.join(self.data_path, self.shards_name[shard_index])
        print('loading {}'.format(filename))
        data = np.load(filename, allow_pickle=True)
        # Shuffle the data
        rng = np.random.RandomState(self.initial_seed + shard_index)
        rng.shuffle(data)
        # Flatten.
        data = np.hstack(data)
        size = (data.shape[0] - 1) // self.seq_length
        last_index = size * self.seq_length + 1
        data = data[0:last_index]
        self.shards_data[shard_index] = data
        indices = np.arange(size) * self.seq_length
        rng.shuffle(indices)
        self.shards_sample_index[shard_index] = indices

    def build_shard_mappings_(self):
        # Load the sizes file.
        sizes_filename = os.path.join(self.data_path, self.sizes_filename)
        if torch.distributed.get_rank() == 0:
            print(' > loading sizes from {}'.format(sizes_filename))
        with open(sizes_filename, 'r') as f:
            self.master_shard_size_dict = json.load(f)
        if torch.distributed.get_rank() == 0:
            print('    found {} shards'.format(len(self.master_shard_size_dict)))
        # Adjust sizes to be a multiple of seq_length.
        self.shard_size_dict = self.master_shard_size_dict.copy()
        total_samples = 0
        for shard in self.shard_size_dict:
            size = self.shard_size_dict[shard]
            size = ((size - 1) // self.seq_length) * self.seq_length
            total_samples += size // self.seq_length
            self.shard_size_dict[shard] = size
        if torch.distributed.get_rank() == 0:
            print('    found {} samples in the dataset'.format(total_samples))
        # Build a list of shards.
        shards_ = np.sort(np.array(list(self.shard_size_dict.keys())))
        rng = np.random.RandomState(self.initial_seed)
        self.shards_name = np.copy(shards_)
        rng.shuffle(self.shards_name)
        for i in range(1, self.max_epochs):
            shards_c = np.copy(shards_)
            rng.shuffle(shards_c)
            self.shards_name = np.append(self.shards_name, shards_c)
        # Build the global indexing.
        self.shards_start_index = np.zeros(self.shards_name.size, dtype=np.int)
        self.shards_start_index[0] = 0
        for i in range(1, self.shards_name.size):
            shard = str(self.shards_name[i - 1])
            size = self.shard_size_dict[shard]
            self.shards_start_index[i] = self.shards_start_index[i - 1] + \
                size // self.seq_length
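The core of __getitem__ above is the mapping from a global sample index to a shard and a local offset via np.searchsorted over the cumulative shards_start_index array. A small self-contained sketch of that lookup, with made-up shard boundaries:

import numpy as np

# Cumulative start index of each shard, as build_shard_mappings_ constructs it:
# shard 0 holds samples [0, 40), shard 1 holds [40, 65), shard 2 starts at 65.
shards_start_index = np.array([0, 40, 65])


def locate(idx):
    # Same lookup as GPT2Dataset.__getitem__.
    shard_index = np.searchsorted(shards_start_index, idx, side='right') - 1
    data_idx = idx - shards_start_index[shard_index]
    return int(shard_index), int(data_idx)


print(locate(0))    # (0, 0)
print(locate(39))   # (0, 39)
print(locate(40))   # (1, 0)
print(locate(70))   # (2, 5)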
megatron/training.py
@@ -37,11 +37,12 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import make_data_loader
 from megatron.utils import report_memory


-def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
+def pretrain(train_valid_test_dataset_provider, model_provider, forward_step_func,
              extra_args_provider=None, args_defaults={}):
     """Main training program.

     This function will run the following in the order provided:
@@ -51,9 +52,9 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     4) train the model using the forward_step_func.

     Arguments:
-        train_val_test_data_provider: a function that builds datasets
-            and returns `train, val, test` data loaders.
-        model_provider: a function that
-            returns a vanilla version of the
+        train_valid_test_dataset_provider: a function that takes the size of
+            train/valid/test dataset and returns `train, valid, test` datasets.
+        model_provider: a function that returns a vanilla version of the
             model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
         forward_step_func: a function that takes a `data iterator` and `model`,
             and returns a `loss` scalar with a dictionary with key:values being
@@ -78,22 +79,15 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     timers('model and optimizer').stop()

     # Data stuff.
-    timers('train/valid/test dataset').start()
-    train_data, val_data, test_data = train_val_test_data_provider()
-    timers('train/valid/test dataset').stop()
-
-    # Train, validation, and test data.
-    timers('train/valid/test dataloader').start()
-    train_data_iterator, val_data_iterator, \
-        test_data_iterator = get_train_val_test_data_iterators(
-            train_data, val_data, test_data)
-    timers('train/valid/test dataloader').stop()
+    timers('train/valid/test data iterators').start()
+    train_data_iterator, valid_data_iterator, test_data_iterator \
+        = build_train_valid_test_data_iterators(
+            train_valid_test_dataset_provider)
+    timers('train/valid/test data iterators').stop()

     # Print setup timing.
     print_rank_0('done with setups ...')
-    timers.log(['model and optimizer', 'train/valid/test dataset',
-                'train/valid/test dataloader'])
+    timers.log(['model and optimizer', 'train/valid/test data iterators'])
     print_rank_0('training ...')
     iteration = 0
@@ -101,13 +95,13 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     if args.do_train:
         iteration, _ = train(forward_step_func, model, optimizer, lr_scheduler,
-                             train_data_iterator, val_data_iterator)
+                             train_data_iterator, valid_data_iterator)

     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
-                                   val_data_iterator, model,
+                                   valid_data_iterator, model,
                                    iteration, False)

     if args.save and iteration != 0:
@@ -152,8 +146,7 @@ def get_model(model_provider_func):
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                 'Exiting.'.format(args.DDP_impl))
-    sys.exit()
+                              'Exiting.'.format(args.DDP_impl))


 def get_optimizer(model):
@@ -352,7 +345,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator):
+          train_data_iterator, valid_data_iterator):
     """Train the model function."""
     args = get_args()
     timers = get_timers()
@@ -403,7 +396,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                 args.do_valid:
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
-                                       val_data_iterator, model,
+                                       valid_data_iterator, model,
                                        iteration, False)

         if args.exit_interval and iteration % args.exit_interval == 0:
@@ -472,37 +465,87 @@ def evaluate_and_print_results(prefix, forward_step_func,
     print_rank_0('-' * length)


-def get_train_val_test_data_iterators(train_data, val_data, test_data):
-    """Build train/validation/test iterators"""
+def build_train_valid_test_data_iterators(
+        build_train_valid_test_datasets_provider):
+    """XXX"""
     args = get_args()

+    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
+
+    print_rank_0('> building train, validation, and test datasets ...')
+    # Data loader only on rank 0 of each model parallel group.
+    if mpu.get_model_parallel_rank() == 0:
+        # Rank, size, and global batch size.
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        global_batch_size = args.batch_size * data_parallel_size
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        # Build the datasets.
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+            train_val_test_num_samples)
+
+        # Build dataloaders.
+        train_dataloader = make_data_loader(train_ds)
+        valid_dataloader = make_data_loader(valid_ds)
+        test_dataloader = make_data_loader(test_ds)
+
+        # Flags to know if we need to do training/validation/testing.
+        do_train = train_dataloader is not None and args.train_iters > 0
+        do_valid = valid_dataloader is not None and args.eval_iters > 0
+        do_test = test_dataloader is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
+    else:
+        flags = torch.cuda.LongTensor([0, 0, 0])
+
+    # Broadcast num tokens.
+    torch.distributed.broadcast(flags,
+                                mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()
+
     # Shift the start iterations.
-    if train_data is not None:
-        train_data.batch_sampler.start_iter = args.iteration % \
-            len(train_data)
+    if train_dataloader is not None:
+        train_dataloader.batch_sampler.start_iter = args.iteration % \
+            len(train_dataloader)
         print_rank_0('setting training data start iteration to {}'.
-                     format(train_data.batch_sampler.start_iter))
-    if val_data is not None:
+                     format(train_dataloader.batch_sampler.start_iter))
+    if valid_dataloader is not None:
         start_iter_val = (args.iteration // args.eval_interval) * \
             args.eval_iters
-        val_data.batch_sampler.start_iter = start_iter_val % \
-            len(val_data)
+        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
+            len(valid_dataloader)
         print_rank_0('setting validation data start iteration to {}'.
-                     format(val_data.batch_sampler.start_iter))
+                     format(valid_dataloader.batch_sampler.start_iter))

-    if train_data is not None:
-        train_data_iterator = iter(train_data)
+    # Build iterators.
+    if train_dataloader is not None:
+        train_data_iterator = iter(train_dataloader)
     else:
         train_data_iterator = None

-    if val_data is not None:
-        val_data_iterator = iter(val_data)
+    if valid_dataloader is not None:
+        valid_data_iterator = iter(valid_dataloader)
     else:
-        val_data_iterator = None
+        valid_data_iterator = None

-    if test_data is not None:
-        test_data_iterator = iter(test_data)
+    if test_dataloader is not None:
+        test_data_iterator = iter(test_dataloader)
     else:
         test_data_iterator = None

-    return train_data_iterator, val_data_iterator, test_data_iterator
+    return train_data_iterator, valid_data_iterator, test_data_iterator
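The sample counts that build_train_valid_test_data_iterators hands to the provider are simple arithmetic over the command-line arguments. A worked sketch with made-up values (none of these numbers come from the commit; the running total is renamed total_eval_iters here for clarity):

# Illustrative values only.
batch_size = 8            # per data-parallel rank
data_parallel_size = 16
train_iters = 500000
eval_interval = 1000
eval_iters = 100          # iterations per evaluation run

global_batch_size = batch_size * data_parallel_size                  # 128
total_eval_iters = (train_iters // eval_interval + 1) * eval_iters   # 50100
test_iters = eval_iters                                              # 100

train_val_test_num_samples = [train_iters * global_batch_size,       # 64000000
                              total_eval_iters * global_batch_size,  # 6412800
                              test_iters * global_batch_size]        # 12800
print(train_val_test_num_samples)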
pretrain_bert.py
@@ -25,13 +25,11 @@ from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import BertModel
 from megatron.training import pretrain
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building BERT model ...')
@@ -44,6 +42,7 @@ def model_provider():
 def get_batch(data_iterator):
+    """Build the batch."""

     # Items and their type.
     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
@@ -96,70 +95,28 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


-def get_train_val_test_data():
-    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
     args = get_args()

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for BERT ...')
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            max_seq_length=args.seq_length,
-            masked_lm_prob=args.mask_prob,
-            short_seq_prob=args.short_seq_prob,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup))
-        print_rank_0("> finished creating BERT datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
-
-    return train_data, valid_data, test_data
+    print_rank_0('> building train, validation, and test datasets '
+                 'for BERT ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        max_seq_length=args.seq_length,
+        masked_lm_prob=args.mask_prob,
+        short_seq_prob=args.short_seq_prob,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup))
+    print_rank_0("> finished creating BERT datasets ...")
+
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":

-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
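The do_train/do_valid/do_test flag broadcast that this commit deletes from pretrain_bert.py (and from pretrain_gpt2.py below) now lives once in build_train_valid_test_data_iterators. A self-contained sketch of that pattern, run single-process on CPU with the gloo backend purely to be runnable; Megatron itself uses CUDA tensors, the model-parallel group, and its source rank. The address and port below are arbitrary choices for the sketch.

import torch
import torch.distributed as dist

# Single-process process group just to make the sketch runnable.
dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)

# Rank 0 of each model-parallel group decides the flags ...
do_train, do_valid, do_test = True, True, False
flags = torch.LongTensor([int(do_train), int(do_valid), int(do_test)])

# ... and every other rank receives them from the group's source rank.
dist.broadcast(flags, src=0)
print(flags[0].item(), flags[1].item(), flags[2].item())  # 1 1 0

dist.destroy_process_group()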
pretrain_gpt2.py
@@ -15,8 +15,6 @@
 """Pretrain GPT2"""

-import os
-
 import torch

 from megatron import get_args
@@ -28,13 +26,11 @@ from megatron.data.gpt2_dataset import build_train_valid_test_datasets
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building GPT2 model ...')
     model = GPT2Model(num_tokentypes=0, parallel_output=True)
@@ -98,68 +94,26 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_loss[0]}


-def get_train_val_test_data():
-    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
     args = get_args()

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for GPT2 ...')
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            seq_length=args.seq_length,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup))
-        print_rank_0("> finished creating GPT2 datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
-
-    return train_data, valid_data, test_data
+    print_rank_0('> building train, validation, and test datasets '
+                 'for GPT2 ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        seq_length=args.seq_length,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup))
+    print_rank_0("> finished creating GPT2 datasets ...")
+
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":

-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
pretrain_gpt2_old.py (deleted, 100644 → 0)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT2"""
import os

import torch

from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.old_gpt2_dataset import GPT2Dataset
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_tokentypes=0, parallel_output=True)

    return model


def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    # Forward model.
    output = model(tokens, position_ids, attention_mask)
    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}


def make_gpt2_dataloaders():
    """Build gpt2 dataloaders."""
    args = get_args()

    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Build the datasets.
    def _build_dataset(name):
        return GPT2Dataset(os.path.join(args.data_path, name),
                           args.input_data_sizes_file,
                           args.seq_length, args.seed)
    train_ds = _build_dataset('train')
    valid_ds = _build_dataset('valid')
    test_ds = _build_dataset('test')

    # Dataloaders
    train = make_data_loader(train_ds)
    valid = make_data_loader(valid_ds)
    test = make_data_loader(test_ds)

    args.do_train = False
    args.do_valid = False
    args.do_test = False
    if train is not None:
        args.do_train = True
    if valid is not None:
        args.do_valid = True
    if test is not None:
        args.do_test = True

    return (train, valid, test)


def get_train_val_test_data():
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
    args = get_args()

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        (train_data, val_data, test_data) = make_gpt2_dataloaders()
        flags = torch.cuda.LongTensor([int(args.do_train),
                                       int(args.do_valid),
                                       int(args.do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    return train_data, val_data, test_data


if __name__ == "__main__":

    pretrain(get_train_val_test_data, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
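One detail worth noting in the deleted get_batch above: the GPT2 dataset returns seq_length + 1 tokens per sample so that model inputs and targets can be produced by a one-position shift. A tiny sketch of that shift with made-up token ids:

import torch

# seq_length + 1 tokens per sample; the values are made up.
tokens_ = torch.tensor([[10, 11, 12, 13, 14]])
tokens = tokens_[:, :-1].contiguous()   # model input:     [[10, 11, 12, 13]]
labels = tokens_[:, 1:].contiguous()    # shifted targets: [[11, 12, 13, 14]]
print(tokens.tolist(), labels.tolist())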