OpenDAS / Megatron-LM / Commits

Commit 981c3dfa, authored Sep 21, 2022 by ANMOL GUPTA
support separate datasets for train, valid and test
Parent: d63c2541

Showing 4 changed files, with 150 additions and 48 deletions:

  megatron/arguments.py            +15   -0
  megatron/data/dataset_utils.py   +10   -6
  megatron/data/gpt_dataset.py     +122  -42
  pretrain_gpt.py                  +3    -0
megatron/arguments.py

@@ -839,6 +839,21 @@ def _add_data_args(parser):
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
+    group.add_argument('--train-data-path', nargs='*', default=None,
+                       help='Path to the training dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--valid-data-path', nargs='*', default=None,
+                       help='Path to the validation dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--test-data-path', nargs='*', default=None,
+                       help='Path to the test dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
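The three new flags reuse the weighted multi-dataset format that --data-path already documents. As a minimal sketch, an alternating weight/path list can be split into prefixes and normalized weights like this (illustrative only: parse_weighted_prefixes is not an upstream helper; in Megatron-LM this parsing happens inside get_datasets_weights_and_num_samples):

# Illustrative sketch: turn ['0.3', 'corpus_a', '0.7', 'corpus_b'] into
# prefixes plus normalized weights, mirroring the documented
# 'dataset1-weight dataset1-path dataset2-weight dataset2-path ...' format.
def parse_weighted_prefixes(data_prefix):
    assert len(data_prefix) % 2 == 0, "expect alternating weight/path tokens"
    weights = [float(w) for w in data_prefix[0::2]]
    prefixes = [p.strip() for p in data_prefix[1::2]]
    total = sum(weights)
    assert total > 0.0
    # Normalize so the weights sum to one.
    return prefixes, [w / total for w in weights]

prefixes, weights = parse_weighted_prefixes(['0.3', 'corpus_a', '0.7', 'corpus_b'])
print(prefixes, weights)  # ['corpus_a', 'corpus_b'] [0.3, 0.7]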
megatron/data/dataset_utils.py

@@ -63,12 +63,16 @@ def get_datasets_weights_and_num_samples(data_prefix,
     # Add 0.5% (the 1.005 factor) so in case the blending dataset does
     # not uniformly distribute the number of samples, we still have
     # samples left to feed to the network.
-    datasets_train_valid_test_num_samples = []
-    for weight in weights:
-        datasets_train_valid_test_num_samples.append(
-            [int(math.ceil(val * weight * 1.005))
-             for val in train_valid_test_num_samples])
+    if isinstance(train_valid_test_num_samples, list):
+        datasets_train_valid_test_num_samples = []
+        for weight in weights:
+            datasets_train_valid_test_num_samples.append(
+                [int(math.ceil(val * weight * 1.005))
+                 for val in train_valid_test_num_samples])
+    else:
+        datasets_train_valid_test_num_samples = [
+            int(math.ceil(train_valid_test_num_samples * weight * 1.005))
+            for weight in weights]
     return prefixes, weights, datasets_train_valid_test_num_samples
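The new isinstance branch exists because the two call sites now pass differently shaped sample counts: the blended --data-path route passes a [train, valid, test] list, while the new per-split build_dataset route (in gpt_dataset.py below) passes a single integer for one split. A self-contained sketch of the same logic with made-up numbers, assuming the 1.005 oversampling factor described in the comment above:

import math

def per_dataset_num_samples(num_samples, weights):
    # Mirrors the patched logic: accept either a [train, valid, test] list
    # or a single per-split integer, and oversample by 0.5% (the 1.005
    # factor) so a blended dataset never runs out of samples.
    if isinstance(num_samples, list):
        return [[int(math.ceil(n * w * 1.005)) for n in num_samples]
                for w in weights]
    return [int(math.ceil(num_samples * w * 1.005)) for w in weights]

weights = [0.3, 0.7]
print(per_dataset_num_samples([1000, 100, 10], weights))
# [[302, 31, 4], [704, 71, 8]]
print(per_dataset_num_samples(1000, weights))
# [302, 704]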
megatron/data/gpt_dataset.py

@@ -28,53 +28,133 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


-def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+def build_train_valid_test_datasets(data_prefix, train_data_prefix,
+                                    valid_data_prefix, test_data_prefix,
+                                    data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
     """Build train, valid, and test datasets."""

-    # Single dataset.
-    if len(data_prefix) == 1:
-        return _build_train_valid_test_datasets(data_prefix[0],
-                                                data_impl, splits_string,
-                                                train_valid_test_num_samples,
-                                                seq_length, seed, skip_warmup)
-
-    # Blending dataset.
-    # Parse the values.
-    output = get_datasets_weights_and_num_samples(data_prefix,
-                                                  train_valid_test_num_samples)
-    prefixes, weights, datasets_train_valid_test_num_samples = output
-
-    # Build individual datasets.
-    train_datasets = []
-    valid_datasets = []
-    test_datasets = []
-    for i in range(len(prefixes)):
-        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
-            prefixes[i], data_impl, splits_string,
-            datasets_train_valid_test_num_samples[i],
-            seq_length, seed, skip_warmup)
-        if train_ds:
-            train_datasets.append(train_ds)
-        if valid_ds:
-            valid_datasets.append(valid_ds)
-        if test_ds:
-            test_datasets.append(test_ds)
-
-    # Blend.
-    blending_train_dataset = None
-    if train_datasets:
-        blending_train_dataset = BlendableDataset(train_datasets, weights)
-    blending_valid_dataset = None
-    if valid_datasets:
-        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
-    blending_test_dataset = None
-    if test_datasets:
-        blending_test_dataset = BlendableDataset(test_datasets, weights)
-
-    return (blending_train_dataset, blending_valid_dataset,
-            blending_test_dataset)
+    if data_prefix:
+        print_rank_0("Single data path provided for train, valid & test")
+        # Single dataset.
+        if len(data_prefix) == 1:
+            return _build_train_valid_test_datasets(
+                data_prefix[0], data_impl, splits_string,
+                train_valid_test_num_samples,
+                seq_length, seed, skip_warmup)
+
+        # Blending dataset.
+        # Parse the values.
+        output = get_datasets_weights_and_num_samples(
+            data_prefix, train_valid_test_num_samples)
+        prefixes, weights, datasets_train_valid_test_num_samples = output
+
+        # Build individual datasets.
+        train_datasets = []
+        valid_datasets = []
+        test_datasets = []
+        for i in range(len(prefixes)):
+            train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+                prefixes[i], data_impl, splits_string,
+                datasets_train_valid_test_num_samples[i],
+                seq_length, seed, skip_warmup)
+            if train_ds:
+                train_datasets.append(train_ds)
+            if valid_ds:
+                valid_datasets.append(valid_ds)
+            if test_ds:
+                test_datasets.append(test_ds)
+
+        # Blend.
+        blending_train_dataset = None
+        if train_datasets:
+            blending_train_dataset = BlendableDataset(train_datasets, weights)
+        blending_valid_dataset = None
+        if valid_datasets:
+            blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+        blending_test_dataset = None
+        if test_datasets:
+            blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+        return (blending_train_dataset, blending_valid_dataset,
+                blending_test_dataset)
+    else:
+        print_rank_0("Separate data paths provided for train, valid & test. "
+                     "Split string will be ignored.")
+
+        assert train_data_prefix is not None
+        train_dataset, valid_dataset, test_dataset = None, None, None
+        # Single dataset.
+        train_dataset = build_dataset("train", train_data_prefix, data_impl,
+                                      train_valid_test_num_samples[0],
+                                      seq_length, seed, skip_warmup)
+
+        if valid_data_prefix is not None:
+            valid_dataset = build_dataset("valid", valid_data_prefix,
+                                          data_impl,
+                                          train_valid_test_num_samples[1],
+                                          seq_length, seed, False)
+
+        if test_data_prefix is not None:
+            test_dataset = build_dataset("test", test_data_prefix, data_impl,
+                                         train_valid_test_num_samples[2],
+                                         seq_length, seed, False)
+
+        return (train_dataset, valid_dataset, test_dataset)
+
+
+def build_dataset(dataset_name, data_prefix, data_impl, num_samples,
+                  seq_length, seed, skip_warmup):
+    dataset = None
+    if len(data_prefix) == 1:
+        dataset = _build_dataset(dataset_name,
+                                 data_prefix[0], data_impl,
+                                 num_samples, seq_length,
+                                 seed, skip_warmup)
+    else:
+        # Blending dataset.
+        # Parse the values.
+        output = get_datasets_weights_and_num_samples(data_prefix,
+                                                      num_samples)
+        prefixes, weights, dataset_num_samples = output
+
+        # Build individual datasets.
+        datasets = []
+        for i in range(len(prefixes)):
+            ds = _build_dataset(dataset_name, prefixes[i],
+                                data_impl, dataset_num_samples[i],
+                                seq_length, seed, skip_warmup)
+            if ds:
+                datasets.append(ds)
+
+        if datasets:
+            dataset = BlendableDataset(datasets, weights)
+
+    return dataset
+
+
+def _build_dataset(dataset_name, data_prefix, data_impl,
+                   num_samples, seq_length, seed, skip_warmup):
+    """
+    Build dataset. This method is called when individual
+    train, valid, test datasets are provided.
+    """
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    total_num_of_documents = indexed_dataset.sizes.shape[0]
+
+    print_rank_0('    {}:'.format(dataset_name))
+    print_rank_0('     document indices in [0, {}) total of {} '
+                 'documents'.format(total_num_of_documents,
+                                    total_num_of_documents))
+
+    documents = np.arange(start=0, stop=total_num_of_documents,
+                          step=1, dtype=np.int32)
+
+    dataset = GPTDataset(dataset_name, data_prefix,
+                         documents, indexed_dataset,
+                         num_samples, seq_length, seed)
+
+    return dataset


 def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
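Taken together, build_train_valid_test_datasets now dispatches on whether --data-path was given: if so, the original split-string behavior applies (a single corpus or a weighted blend carved into train/valid/test); if not, each split is built independently from its own prefix list, the split string is ignored, and only the train prefix is mandatory. A stripped-down sketch of that dispatch, with dataset construction stubbed out as strings (illustrative only; the real function returns GPTDataset/BlendableDataset objects and takes more parameters):

# Stripped-down sketch of the new dispatch in build_train_valid_test_datasets.
def dispatch_datasets(data_prefix, train_data_prefix,
                      valid_data_prefix, test_data_prefix):
    if data_prefix:
        # Original behavior: one corpus (or weighted blend) carved into
        # train/valid/test according to --split.
        return ("split:train", "split:valid", "split:test")
    # New behavior: independent corpora per split; --split is ignored.
    assert train_data_prefix is not None
    train = f"train:{train_data_prefix[0]}"
    valid = f"valid:{valid_data_prefix[0]}" if valid_data_prefix else None
    test = f"test:{test_data_prefix[0]}" if test_data_prefix else None
    return (train, valid, test)

print(dispatch_datasets(None, ['tr'], ['va'], None))
# ('train:tr', 'valid:va', None)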
pretrain_gpt.py

@@ -108,6 +108,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
                 'for GPT ...')
     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
         data_prefix=args.data_path,
+        train_data_prefix=args.train_data_path,
+        valid_data_prefix=args.valid_data_path,
+        test_data_prefix=args.test_data_path,
         data_impl=args.data_impl,
         splits_string=args.split,
         train_valid_test_num_samples=train_val_test_num_samples,
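With the wiring above, which code path runs is decided entirely by the flags the launch script passes. A hedged argparse sketch of the two invocation styles (dataset paths below are placeholders, not real files):

import argparse

# Demonstration of the two launch modes; the flag names match the patch.
parser = argparse.ArgumentParser()
parser.add_argument('--data-path', nargs='*', default=None)
parser.add_argument('--train-data-path', nargs='*', default=None)
parser.add_argument('--valid-data-path', nargs='*', default=None)
parser.add_argument('--test-data-path', nargs='*', default=None)

# Mode 1: single corpus divided by --split (unchanged behavior).
blended = parser.parse_args(['--data-path', 'my-corpus_text_document'])

# Mode 2: separate corpora per split; the split string is ignored.
separate = parser.parse_args([
    '--train-data-path', 'train-corpus_text_document',
    '--valid-data-path', 'valid-corpus_text_document',
    '--test-data-path', 'test-corpus_text_document'])

print(blended.data_path)          # ['my-corpus_text_document']
print(separate.train_data_path)   # ['train-corpus_text_document']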