OpenDAS / Megatron-LM · Commits

Commit 0601702a, authored Nov 24, 2019 by Mohammad Shoeybi
Parent: f6a6811f

    zero worker seems to be working
Showing 5 changed files with 305 additions and 180 deletions.

    megatron/data/albert_dataset.py    +154  -24
    megatron/data/helpers.cpp           +40  -44
    megatron/data/indexed_dataset.py    +13   -9
    megatron/data/split_dataset.py      +19  -28
    pretrain_albert.py                  +79  -75
megatron/data/albert_dataset.py
@@ -8,6 +8,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

+from megatron import mpu
 from megatron.data import helpers
 from megatron.data import FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
@@ -15,22 +16,97 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.utils import print_rank_0


+def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
+                                    splits_string, train_valid_test_num_samples,
+                                    max_seq_length, masked_lm_prob,
+                                    short_seq_prob, seed, skip_warmup):
+
+    # Tokenizer is the same.
+    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
+    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
+        tokenizer.vocab_size()))
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
+
+    # Get start and end indices of train/valid/test into doc-idx.
+    # Note that doc-idx is designed to be num-docs + 1 so we can
+    # easily iterate over it.
+    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+
+    def print_split_stats(name, index):
+        print_rank_0('    {}:'.format(name))
+        print_rank_0('     document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+        start_index = indexed_dataset.doc_idx[splits[index]]
+        end_index = indexed_dataset.doc_idx[splits[index + 1]]
+        print_rank_0('     sentence indices in [{}, {}) total of {} '
+                     'sentences'.format(start_index, end_index,
+                                        end_index - start_index))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+
+    def build_dataset(index, name):
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            # Get the pointer to the original doc-idx so we can set it later.
+            doc_idx_ptr = indexed_dataset.get_doc_idx()
+            # Slice the doc-idx.
+            start_index = splits[index]
+            # Add +1 so we can index into the dataset to get the upper bound.
+            end_index = splits[index + 1] + 1
+            # New doc_idx view.
+            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
+            # Build the dataset accordingly.
+            dataset = AlbertDataset(
+                name=name,
+                indexed_dataset=indexed_dataset,
+                tokenizer=tokenizer,
+                data_prefix=data_prefix,
+                num_epochs=None,
+                max_num_samples=train_valid_test_num_samples[index],
+                masked_lm_prob=masked_lm_prob,
+                max_seq_length=max_seq_length,
+                short_seq_prob=short_seq_prob,
+                seed=seed)
+            # Set the original pointer so dataset remains the main dataset.
+            indexed_dataset.set_doc_idx(doc_idx_ptr)
+            # Checks.
+            assert indexed_dataset.doc_idx[0] == 0
+            assert indexed_dataset.doc_idx.shape[0] == \
+                (total_num_of_documents + 1)
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+
+    return (train_dataset, valid_dataset, test_dataset)


 class AlbertDataset(Dataset):

-    def __init__(self, vocab_file, data_prefix, data_impl, skip_warmup,
-                 num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed):
+    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+                 num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed):

         # Params to store.
+        self.name = name
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
-        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)

-        # Indexed dataset.
-        self.indexed_dataset = get_indexed_dataset_(data_prefix,
-                                                    data_impl,
-                                                    skip_warmup)
+        # Tokenizer and dataset.
+        self.tokenizer = tokenizer
+        self.indexed_dataset = indexed_dataset

         # Build the samples mapping.
         self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
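For orientation, the new build_train_valid_test_datasets entry point above is intended to be called once per job with a split string and per-split sample targets. A minimal illustrative call site follows; every argument value here (paths, vocab file, sizes) is a made-up placeholder, not something taken from the repository's configs.

    # Illustrative only: all values below are hypothetical placeholders.
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        vocab_file='bert-vocab.txt',
        data_prefix='my-corpus_text_sentence',
        data_impl='mmap',
        splits_string='949,50,1',
        train_valid_test_num_samples=[1000000, 10000, 1000],
        max_seq_length=512,
        masked_lm_prob=0.15,
        short_seq_prob=0.1,
        seed=1234,
        skip_warmup=True)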
@@ -39,7 +115,8 @@ class AlbertDataset(Dataset):
                                                     max_num_samples,
                                                     self.max_seq_length,
                                                     short_seq_prob,
-                                                    self.seed)
+                                                    self.seed,
+                                                    self.name)

         # Vocab stuff.
         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
@@ -48,7 +125,6 @@ class AlbertDataset(Dataset):
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.mask_id = self.tokenizer.vocab['[MASK]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
-        exit()

     def num_tokens(self):
@@ -68,9 +144,11 @@ class AlbertDataset(Dataset):
         sample = []
         for index in range(start_index, end_index):
             sample.append(self.indexed_dataset[index])
+        '''
         for s in sample:
             if len(s) > 1000:
                 print(self.tokenizer.convert_ids_to_tokens(s))
+        '''
         return build_training_sample(sample, seq_length,
                                      self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
@@ -80,25 +158,63 @@ class AlbertDataset(Dataset):
                                      self.masked_lm_prob, rng)


 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+    print_rank_0(' > building dataset index ...')
+
     start_time = time.time()
-    print_rank_0("> Reading dataset index ...")
     indexed_dataset = make_indexed_dataset(data_prefix,
                                            data_impl,
                                            skip_warmup)
-    print_rank_0("> Finished creating indexed dataset in {:4f} "
-                 "seconds".format(time.time() - start_time))
+    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
+    print_rank_0(' > finished creating indexed dataset in {:4f} '
+                 'seconds'.format(time.time() - start_time))
+    print_rank_0(' > indexed dataset stats:')
+    print_rank_0('    number of documents: {}'.format(
+        indexed_dataset.doc_idx.shape[0] - 1))
+    print_rank_0('    number of sentences: {}'.format(
+        indexed_dataset.sizes.shape[0]))
+
     return indexed_dataset


+def get_train_valid_test_split_(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
+    splits = []
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
+    else:
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    assert len(splits_index) == 4
+    assert splits_index[-1] == size
+    return splits_index
+
+
 def get_samples_mapping_(indexed_dataset,
                          data_prefix,
                          num_epochs,
                          max_num_samples,
                          max_seq_length,
                          short_seq_prob,
-                         seed):
+                         seed,
+                         name):

     if not num_epochs:
         if not max_num_samples:
             raise ValueError("Need to specify either max_num_samples "
@@ -109,9 +225,11 @@ def get_samples_mapping_(indexed_dataset,

     # Filename of the index mapping
     indexmap_filename = data_prefix
-    indexmap_filename += '_indexmap'
-    indexmap_filename += '_{}ep'.format(num_epochs)
-    indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
     indexmap_filename += '_{}msl'.format(max_seq_length)
     indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
     indexmap_filename += '_{}s'.format(seed)
@@ -120,8 +238,9 @@ def get_samples_mapping_(indexed_dataset,

     # Build the indexed mapping if not exist.
     if torch.distributed.get_rank() == 0 and \
        not os.path.isfile(indexmap_filename):
-        print('WARNING: could not find index map file {}, building '
+        print(' > WARNING: could not find index map file {}, building '
               'the indices on rank 0 ...'.format(indexmap_filename))
         # Make sure the types match the helpers input types.
         assert indexed_dataset.doc_idx.dtype == np.int64
         assert indexed_dataset.sizes.dtype == np.int32
@@ -129,6 +248,8 @@ def get_samples_mapping_(indexed_dataset,

         # Build samples mapping
         verbose = torch.distributed.get_rank() == 0
         start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+            name))
         samples_mapping = helpers.build_mapping(
             indexed_dataset.doc_idx,
             indexed_dataset.sizes,
@@ -138,21 +259,30 @@ def get_samples_mapping_(indexed_dataset,
             short_seq_prob,
             seed,
             verbose)
+        print_rank_0(' > done building samples index mapping')
         np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+        print_rank_0(' > saved the index mapping in {}'.format(
+            indexmap_filename))
         # Make sure all the ranks have built the mapping
-        print_rank_0('> elapsed time to build and save samples mapping '
+        print_rank_0(' > elapsed time to build and save samples mapping '
                      '(seconds): {:4f}'.format(time.time() - start_time))
-    torch.distributed.barrier()
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())

     # Load indexed dataset.
-    print_rank_0('> loading indexed mapping from {}'.format(
+    print_rank_0(' > loading indexed mapping from {}'.format(
         indexmap_filename))
     start_time = time.time()
     samples_mapping = np.load(indexmap_filename, allow_pickle=True)
     print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0('    total number of samples: {}'.format(
         samples_mapping.shape[0]))

     return samples_mapping
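Note on the synchronization change above: torch.distributed.barrier() is replaced by an all_reduce over the data-parallel group because, as the new comment states, the NCCL barrier assumes device_index == rank, which does not hold under model parallelism. A stripped-down sketch of the same pattern, with the process group left as an assumed argument rather than taken from mpu:

    import torch

    def group_barrier(group):
        # Each rank contributes 1; the all-reduce only completes once every
        # rank in the group has reached this point, so it acts as a barrier
        # that respects the group's own device mapping.
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=group)
        assert counts[0].item() == torch.distributed.get_world_size(group=group)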
megatron/data/helpers.cpp
@@ -39,12 +39,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
        and sequence-length is the target sequence length.
     */

-    if (verbose) {
-        cout << " > using " << docs_.shape(0) - 1 << " documents with "
-             << sizes_.shape(0) << " sentences ..." << endl << std::flush;
-    }
-
     // Consistency checks.
     assert(num_epochs > 0);
     assert(max_seq_length > 1);
@@ -52,16 +46,36 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     assert(short_seq_prob <= 1.0);
     assert(seed > 0);

-    // For efficiency, convert probability to ratio. Note: rand() generates int.
-    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
-
     // Remove bound checks.
     auto docs = docs_.unchecked<1>();
     auto sizes = sizes_.unchecked<1>();

-    if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
-        cout << "document values is not consistent with length of sizes: " <<
-            docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
-        throw std::length_error("docs and sizes");
-    }
+    // For efficiency, convert probability to ratio. Note: rand() generates int.
+    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+
+    if (verbose) {
+        const auto sent_start_index = docs[0];
+        const auto sent_end_index = docs[docs_.shape(0) - 1];
+        const auto num_sentences = sent_end_index - sent_start_index;
+        cout << "    using:" << endl << std::flush;
+        cout << "     number of documents:           " << docs_.shape(0) - 1 <<
+            endl << std::flush;
+        cout << "     sentences range:               [" << sent_start_index <<
+            ", " << sent_end_index << ")" << endl << std::flush;
+        cout << "     total number of sentences:     " << num_sentences <<
+            endl << std::flush;
+        cout << "     number of epochs:              " << num_epochs <<
+            endl << std::flush;
+        cout << "     maximum number of samples:     " << max_num_samples <<
+            endl << std::flush;
+        cout << "     maximum sequence length:       " << max_seq_length <<
+            endl << std::flush;
+        cout << "     short sequence probability:    " << short_seq_prob <<
+            endl << std::flush;
+        cout << "     short sequence ratio (1/prob): " << short_seq_ratio <<
+            endl << std::flush;
+        cout << "     seed:                          " << seed << endl <<
+            std::flush;
+    }

     // Mapping and its length (1D).
@@ -90,7 +104,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     for (int32_t epoch=0; epoch < num_epochs; ++epoch) {
         if (map_index >= max_num_samples) {
             if (verbose && (!second)) {
-                cout << " > reached " << max_num_samples << " samples after "
+                cout << "    reached " << max_num_samples << " samples after "
                      << epoch << " epochs ..." << endl << std::flush;
             }
             break;
@@ -181,11 +195,11 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
         if (!second) {
             if (verbose) {
-                cout << " > number of empty documents: " << empty_docs <<
+                cout << "   number of empty documents: " << empty_docs <<
                     endl << std::flush;
-                cout << " > number of documents with one sentence: " <<
+                cout << "   number of documents with one sentence: " <<
                     one_sent_docs << endl << std::flush;
-                cout << " > will create mapping for " << map_index <<
+                cout << "   will create mapping for " << map_index <<
                     " samples" << endl << std::flush;
             }
             assert(maps == NULL);
@@ -210,10 +224,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
             swap(maps[i0 + 2], maps[j0 + 2]);
         }
     }

-    if (verbose) {
-        cout << "> done building the mapping." << endl;
-    }
-
     // Method to deallocate memory.
     py::capsule free_when_done(maps, [](void *mem_) {
         DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
@@ -239,34 +249,20 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                         const int seed,
                         const bool verbose) {

-    if (verbose) {
-        cout << "> building sample map using: " << endl << std::flush;
-        cout << "   number of epochs: " << num_epochs << endl << std::flush;
-        cout << "   maximum number of samples: " << max_num_samples << endl <<
-            std::flush;
-        cout << "   maximum sequence length: " << max_seq_length << endl <<
-            std::flush;
-        cout << "   short sequence probability: " << short_seq_prob << endl <<
-            std::flush;
-        cout << "   seed: " << seed << endl << std::flush;
-    }
-
     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
         if (verbose) {
-            cout << " > using uint64 for data mapping..." << endl << std::flush;
+            cout << "    using uint64 for data mapping..." << endl << std::flush;
         }
         return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                             max_num_samples, max_seq_length,
                                             short_seq_prob, seed, verbose);
     }
     else {
         if (verbose) {
-            cout << " > using uint32 for data mapping..." << endl << std::flush;
+            cout << "    using uint32 for data mapping..." << endl << std::flush;
         }
         return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                             max_num_samples, max_seq_length,
                                             short_seq_prob, seed, verbose);
     }
 }
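A note on the dispatch in build_mapping above: the wrapper picks the 64-bit implementation only when the number of sentences exceeds the uint32 range, so the dtype of the returned samples mapping depends on corpus size. A hedged Python-side sketch of the same decision (illustrative only, not the extension's actual code):

    import numpy as np

    def mapping_dtype(num_sentences):
        # Mirror of the C++ dispatch: use 64-bit indices only when 32-bit
        # indices could overflow.
        if num_sentences > np.iinfo(np.uint32).max:
            return np.uint64
        return np.uint32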
megatron/data/indexed_dataset.py
@@ -391,17 +391,17 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                 offset = stream.tell()

             if not skip_warmup:
-                print_rank_0("> Warming up index mmap file...")
+                print_rank_0("    warming up index mmap file...")
                 _warmup_mmap_file(path)

             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
-            print_rank_0("> Reading sizes...")
+            print_rank_0("    reading sizes...")
             self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32,
                                         count=self._len, offset=offset)
-            print_rank_0("> Reading pointers...")
+            print_rank_0("    reading pointers...")
             self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64,
                                            count=self._len,
                                            offset=offset + self._sizes.nbytes)
-            print_rank_0("> Reading document index...")
+            print_rank_0("    reading document index...")
             self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64,
                                           count=self._doc_count,
                                           offset=offset + self._sizes.nbytes +
                                           self._pointers.nbytes)

         def __del__(self):
@@ -447,13 +447,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         self._index = self.Index(index_file_path(self._path), skip_warmup)

         if not skip_warmup:
-            print_rank_0("> Warming up data mmap file...")
+            print_rank_0("    warming up data mmap file...")
             _warmup_mmap_file(data_file_path(self._path))
-        print_rank_0("> Creating numpy buffer of mmap...")
+        print_rank_0("    creating numpy buffer of mmap...")
         self._bin_buffer_mmap = np.memmap(data_file_path(self._path),
                                           mode='r', order='C')
-        print_rank_0("> Creating memory view of numpy buffer...")
+        print_rank_0("    creating memory view of numpy buffer...")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)
-        print_rank_0("> Done")

     def __del__(self):
         self._bin_buffer_mmap._mmap.close()
@@ -470,7 +469,6 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
             if self._index.dtype != np.int64:
                 np_array = np_array.astype(np.int64)
-
             return np_array
         elif isinstance(idx, slice):
             start, stop, step = idx.indices(len(self))
@@ -492,6 +490,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def doc_idx(self):
         return self._index.doc_idx

+    def get_doc_idx(self):
+        return self._index._doc_idx
+
+    def set_doc_idx(self, doc_idx_):
+        self._index._doc_idx = doc_idx_
+
     @property
     def supports_prefetch(self):
         return False
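The new get_doc_idx/set_doc_idx accessors let a caller temporarily narrow an MMapIndexedDataset to a contiguous document range and then restore the original index, which is how build_dataset in albert_dataset.py carves out the train/valid/test views. A minimal sketch of the pattern; the dataset object and the range here are assumed for illustration:

    # 'ds' is assumed to be an MMapIndexedDataset; the range is hypothetical.
    full_doc_idx = ds.get_doc_idx()        # keep a handle to restore later
    ds.set_doc_idx(full_doc_idx[100:201])  # view documents [100, 200)
    num_sentences = ds.doc_idx[-1] - ds.doc_idx[0]
    ds.set_doc_idx(full_doc_idx)           # restore the original index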
megatron/data/split_dataset.py
@@ -13,43 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """dataset to split one large one into multiple smaller datasets"""

 import torch
 import numpy as np


-def should_split(split):
-    """
-    given split proportions checks if should split
-    Examples:
-    >>> should_split([10,0,0])
-    False
-    >>> should_split([1,.1,.2])
-    True
-    """
-    return max(split) / sum(split) != 1.
-
-
-def get_split(args):
-    """
-    Get dataset splits from comma separated string list
-    """
+def get_train_valid_test_split(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
     splits = []
-    if args.split.find(',') != -1:
-        splits = [float(s) for s in args.split.split(',')]
-    elif args.split.find('/') != -1:
-        splits = [float(s) for s in args.split.split('/')]
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
     else:
-        splits = [float(args.split)]
-    split_total = sum(splits)
-    if split_total < 1.:
-        splits.append(1 - split_total)
+        splits = [float(splits_string)]
     while len(splits) < 3:
         splits.append(0.)
     splits = splits[:3]
-    if args.valid_data is not None:
-        splits[1] = 0.
-    if args.test_data is not None:
-        splits[2] = 0.
-    final_sum = sum(splits)
-    return [s / final_sum for s in splits]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    return splits_index


 class SplitDataset(torch.utils.data.Dataset):
     """
pretrain_albert.py
@@ -13,21 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Pretrain BERT"""
+"""Pretrain ALBERT"""

 import torch
 import torch.nn.functional as F

-from configure_data import configure_data
 from megatron import mpu
 from megatron.model import BertModel
 from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
-from megatron.data import AlbertDataset, split_dataset
+from megatron.data.albert_dataset import build_train_valid_test_datasets
 from megatron.data_utils.samplers import DistributedBatchSampler


 def model_provider(args):
     """Build the model."""
@@ -109,94 +109,98 @@ def forward_step(data_iterator, model, args, timers):
 def get_train_val_test_data(args):
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""

-    (train_data, val_data, test_data) = (None, None, None)
+    (train_data, valid_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        if args.data_loader == None:
-            args.data_loader = 'binary'
-        if args.data_loader == 'binary':
-            if not args.max_num_samples:
-                args.max_num_samples = (args.train_iters +
-                                        2 * args.eval_iters) * args.batch_size
-            if not args.data_path:
-                print("Albert currently only supports a unified dataset "
-                      "specified with --data-path")
-                exit(1)
-            print_rank_0("Creating AlbertDataset...")
-            full_data = AlbertDataset(vocab_file=args.vocab,
-                                      data_prefix=args.data_path,
-                                      data_impl=args.data_impl,
-                                      skip_warmup=args.skip_mmap_warmup,
-                                      num_epochs=args.data_epochs,
-                                      max_num_samples=args.max_num_samples,
-                                      masked_lm_prob=args.mask_prob,
-                                      max_seq_length=args.seq_length,
-                                      short_seq_prob=args.short_seq_prob,
-                                      seed=args.seed)
-            print_rank_0("Finished creating AlbertDataset...")
-            split = split_dataset.get_split(args)
-            if split_dataset.should_split(split):
-                train_ds, val_ds, test_ds = split_dataset.split_ds(
-                    full_data, split, args.shuffle)
-            else:
-                train_ds = full_data
-            num_tokens = train_ds.num_tokens()
-
-            world_size = mpu.get_data_parallel_world_size()
-            rank = mpu.get_data_parallel_rank()
-            global_batch_size = args.batch_size * world_size
-            num_workers = args.num_workers
-
-            def make_data_loader_(dataset):
-                if not dataset:
-                    return None
-                # Use a simple sampler with distributed batch sampler.
-                sampler = torch.utils.data.SequentialSampler(dataset)
-                batch_sampler = DistributedBatchSampler(
-                    sampler=sampler,
-                    batch_size=global_batch_size,
-                    drop_last=True,
-                    rank=rank,
-                    world_size=world_size)
-                # Torch dataloader.
-                return torch.utils.data.DataLoader(dataset,
-                                                   batch_sampler=batch_sampler,
-                                                   num_workers=num_workers,
-                                                   pin_memory=True)
-
-            train_data = make_data_loader_(train_ds)
-            valid_data = make_data_loader_(val_ds)
-            test_data = make_data_loader_(test_ds)
-
-            do_train = train_data is not None and args.train_iters > 0
-            do_valid = valid_data is not None and args.eval_iters > 0
-            do_test = test_data is not None and args.eval_iters > 0
-            # Need to broadcast num_tokens and num_type_tokens.
-            token_counts = torch.cuda.LongTensor(
-                [num_tokens,
-                 2,  # hard coded num_type_tokens for now
-                 int(do_train),
-                 int(do_valid),
-                 int(do_test)])
-        else:
-            print("Unsupported data loader for BERT.")
-            exit(1)
+        print_rank_0('> building train, validation, and test datasets '
+                     'for ALBERT ...')
+        if args.data_loader is None:
+            args.data_loader = 'binary'
+        if args.data_loader != 'binary':
+            print('Unsupported {} data loader for ALBERT.'.format(
+                args.data_loader))
+            exit(1)
+        if not args.data_path:
+            print('ALBERT only supports a unified dataset specified '
+                  'with --data-path')
+            exit(1)
+
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            vocab_file=args.vocab,
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            max_seq_length=args.seq_length,
+            masked_lm_prob=args.mask_prob,
+            short_seq_prob=args.short_seq_prob,
+            seed=args.seed,
+            skip_warmup=args.skip_mmap_warmup)
+        print_rank_0("> finished creating ALBERT datasets ...")
+
+        def make_data_loader_(dataset):
+            if not dataset:
+                return None
+            # Use a simple sampler with distributed batch sampler.
+            sampler = torch.utils.data.SequentialSampler(dataset)
+            batch_sampler = DistributedBatchSampler(
+                sampler=sampler,
+                batch_size=global_batch_size,
+                drop_last=True,
+                rank=data_parallel_rank,
+                world_size=data_parallel_size)
+            # Torch dataloader.
+            return torch.utils.data.DataLoader(dataset,
+                                               batch_sampler=batch_sampler,
+                                               num_workers=args.num_workers,
+                                               pin_memory=True)
+
+        train_data = make_data_loader_(train_ds)
+        valid_data = make_data_loader_(valid_ds)
+        test_data = make_data_loader_(test_ds)
+
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
+        token_counts = torch.cuda.LongTensor([num_tokens,
+                                              2,  # hard coded num_type_tokens
+                                              int(do_train),
+                                              int(do_valid),
+                                              int(do_test)])
     else:
         token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

     # Broadcast num tokens.
     torch.distributed.broadcast(token_counts,
                                 mpu.get_model_parallel_src_rank(),
                                 group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    num_type_tokens = token_counts[1].item()
+    args.vocab_size = token_counts[0].item()
+    args.tokentype_size = token_counts[1].item()
     args.do_train = token_counts[2].item()
     args.do_valid = token_counts[3].item()
     args.do_test = token_counts[4].item()
-    args.vocab_size = num_tokens
-    args.tokentype_size = num_type_tokens

-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data


 if __name__ == "__main__":
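For a feel of the sample-target arithmetic in get_train_val_test_data above, here is the computation with hypothetical settings (all numbers below are made up for the example):

    # Hypothetical settings, for illustrating the arithmetic only.
    batch_size = 4                  # per-GPU batch size
    data_parallel_size = 8
    global_batch_size = batch_size * data_parallel_size        # 32

    train_iters = 1000000
    eval_interval = 1000
    eval_iters = 10

    total_eval_iters = (train_iters // eval_interval + 1) * eval_iters   # 10010
    train_val_test_num_samples = [
        train_iters * global_batch_size,       # 32,000,000 train samples
        total_eval_iters * global_batch_size,  # 320,320 validation samples
        eval_iters * global_batch_size]        # 320 test samples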