OpenDAS / Megatron-LM · Commits

Commit 0601702a, authored Nov 24, 2019 by Mohammad Shoeybi

    zero worker seems to be working

Parent: f6a6811f
Showing 5 changed files with 305 additions and 180 deletions (+305 −180)
megatron/data/albert_dataset.py   +154 −24
megatron/data/helpers.cpp          +40 −44
megatron/data/indexed_dataset.py   +13 −9
megatron/data/split_dataset.py     +19 −28
pretrain_albert.py                 +79 −75
megatron/data/albert_dataset.py
@@ -8,6 +8,7 @@ import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import mpu
from megatron.data import helpers
from megatron.data import FullBertTokenizer
from megatron.data.dataset_utils import build_training_sample
@@ -15,22 +16,97 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.utils import print_rank_0


def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
                                    splits_string, train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup):

    # Tokenizer is the same
    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
        tokenizer.vocab_size()))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/train into doc-idx
    # Note that doc-idx is desinged to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            dataset = AlbertDataset(name=name,
                                    indexed_dataset=indexed_dataset,
                                    tokenizer=tokenizer,
                                    data_prefix=data_prefix,
                                    num_epochs=None,
                                    max_num_samples=train_valid_test_num_samples[index],
                                    masked_lm_prob=masked_lm_prob,
                                    max_seq_length=max_seq_length,
                                    short_seq_prob=short_seq_prob,
                                    seed=seed)
            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
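As a quick orientation, this is roughly how the new entry point is meant to be called (a hedged sketch; only the signature comes from the diff, the argument values are illustrative):

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        vocab_file='bert-vocab.txt',
        data_prefix='my-corpus_text_sentence',
        data_impl='mmap',
        splits_string='949,50,1',
        train_valid_test_num_samples=[1000000, 10000, 1000],
        max_seq_length=512,
        masked_lm_prob=0.15,
        short_seq_prob=0.1,
        seed=1234,
        skip_warmup=False)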
class AlbertDataset(Dataset):

-    def __init__(self, vocab_file, data_prefix, data_impl, skip_warmup,
-                 num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed):
+    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+                 num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed):

        # Params to store.
+        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

-        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
-        # Indexed dataset.
-        self.indexed_dataset = get_indexed_dataset_(data_prefix, data_impl,
-                                                    skip_warmup)
+        # Tokenizer and dataset.
+        self.tokenizer = tokenizer
+        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,

@@ -39,7 +115,8 @@ class AlbertDataset(Dataset):
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
-                                                    self.seed)
+                                                    self.seed,
+                                                    self.name)

        # Vocab stuff.
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
@@ -48,7 +125,6 @@ class AlbertDataset(Dataset):
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.mask_id = self.tokenizer.vocab['[MASK]']
        self.pad_id = self.tokenizer.vocab['[PAD]']
-        exit()

    def num_tokens(self):
@@ -68,9 +144,11 @@ class AlbertDataset(Dataset):
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        '''
        for s in sample:
            if len(s) > 1000:
                print(self.tokenizer.convert_ids_to_tokens(s))
        '''
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
@@ -80,25 +158,63 @@ class AlbertDataset(Dataset):
                                     self.masked_lm_prob, rng)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')
    start_time = time.time()
-    print_rank_0("> Reading dataset index ...")
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
-    print_rank_0("> Finished creating indexed dataset in {:4f} "
-                 "seconds".format(time.time() - start_time))
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
+    print_rank_0(' > finished creating indexed dataset in {:4f} '
+                 'seconds'.format(time.time() - start_time))
+    print_rank_0(' > indexed dataset stats:')
+    print_rank_0('    number of documents: {}'.format(
+        indexed_dataset.doc_idx.shape[0] - 1))
+    print_rank_0('    number of sentences: {}'.format(
+        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
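A worked example of the split arithmetic above (a hedged illustration; the numbers are not from the commit): for splits_string='949,50,1' and size=10000, the weights normalize to [0.949, 0.05, 0.001], so

    splits_index = get_train_valid_test_split_('949,50,1', 10000)
    # [0, 9490, 9990, 10000] -> train docs [0, 9490), valid [9490, 9990), test [9990, 10000)

A single number such as '0.9' is padded with zeros to three entries, so every document lands in the first split.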
def get_samples_mapping_(indexed_dataset,
                         data_prefix,
                         num_epochs,
                         max_num_samples,
                         max_seq_length,
                         short_seq_prob,
-                         seed):
+                         seed,
+                         name):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
@@ -109,9 +225,11 @@ def get_samples_mapping_(indexed_dataset,
    # Filename of the index mapping
    indexmap_filename = data_prefix
-    indexmap_filename += '_indexmap'
-    indexmap_filename += '_{}ep'.format(num_epochs)
-    indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
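For concreteness, with name='train', max_num_samples=1000000, max_seq_length=512, short_seq_prob=0.1 and seed=1234 (illustrative values, and assuming num_epochs was left at its int32 sentinel so the ep tag is skipped), the cached mapping would be saved under a name of roughly this shape:

    # <data_prefix>_train_indexmap_1000000mns_512msl_0.10ssp_1234s.npy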
@@ -120,8 +238,9 @@ def get_samples_mapping_(indexed_dataset,
    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
-        print('WARNING: could not find index map file {}, building '
+        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32
@@ -129,6 +248,8 @@ def get_samples_mapping_(indexed_dataset,
        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
+        print_rank_0(' > building sapmles index mapping for {} ...'.format(
+            name))
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
@@ -138,21 +259,30 @@ def get_samples_mapping_(indexed_dataset,
            short_seq_prob,
            seed,
            verbose)
+        print_rank_0(' > done building sapmles index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
-        print_rank_0('> elasped time to build and save samples mapping '
+        print_rank_0(' > elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))
-    torch.distributed.barrier()
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
-    print_rank_0('> loading indexed mapping from {}'.format(
+    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
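The counts/all_reduce lines above act as a stand-in barrier. A minimal sketch of the same idiom in isolation (assuming torch.distributed is initialized with the NCCL backend and `group` is the data-parallel group from mpu):

    import torch

    def barrier_via_all_reduce(group):
        # Each rank contributes 1; the all-reduce can only complete once every
        # rank in the group has reached this point, so the sum equals the
        # group's world size on all ranks.
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=group)
        assert counts[0].item() == torch.distributed.get_world_size(group=group)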
megatron/data/helpers.cpp
@@ -39,12 +39,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
       and sequence-length is the target sequence length.
    */

-    if (verbose) {
-        cout << " > using " << docs_.shape(0) - 1 << " documents with "
-             << sizes_.shape(0) << " sentences ..."
-             << endl << std::flush;
-    }
-
    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
@@ -52,16 +46,36 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

-    // For efficiency, convert probability to ratio. Note: rand() generates int.
-    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
        cout << "document values is not consistent with length of sizes: "
             << docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
        throw std::length_error("docs and sizes");
    }

+    // For efficiency, convert probability to ratio. Note: rand() generates int.
+    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+
+    if (verbose) {
+        const auto sent_start_index = docs[0];
+        const auto sent_end_index = docs[docs_.shape(0) - 1];
+        const auto num_sentences = sent_end_index - sent_start_index;
+        cout << " using:" << endl << std::flush;
+        cout << "   number of documents: " << docs_.shape(0) - 1 << endl << std::flush;
+        cout << "   sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl << std::flush;
+        cout << "   total number of sentences: " << num_sentences << endl << std::flush;
+        cout << "   number of epochs: " << num_epochs << endl << std::flush;
+        cout << "   maximum number of samples: " << max_num_samples << endl << std::flush;
+        cout << "   maximum sequence length: " << max_seq_length << endl << std::flush;
+        cout << "   short sequence probability: " << short_seq_prob << endl << std::flush;
+        cout << "   short sequence ration (1/prob): " << short_seq_ratio << endl << std::flush;
+        cout << "   seed: " << seed << endl << std::flush;
+    }

    // Mapping and it's length (1D).
@@ -90,7 +104,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
    for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
        if (map_index >= max_num_samples) {
            if (verbose && (!second)) {
-                cout << " > reached " << max_num_samples << " samples after "
+                cout << "   reached " << max_num_samples << " samples after "
                     << epoch << " epochs ..." << endl << std::flush;
            }
            break;
@@ -181,11 +195,11 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
        if (!second) {
            if (verbose) {
-                cout << " > number of empty documents: " << empty_docs <<
+                cout << "   number of empty documents: " << empty_docs <<
                    endl << std::flush;
-                cout << " > number of documents with one sentence: " <<
+                cout << "   number of documents with one sentence: " <<
                    one_sent_docs << endl << std::flush;
-                cout << " > will create mapping for " << map_index <<
+                cout << "   will create mapping for " << map_index <<
                    " samples" << endl << std::flush;
            }
            assert(maps == NULL);
@@ -210,10 +224,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
            swap(maps[i0 + 2], maps[j0 + 2]);
        }

-    if (verbose) {
-        cout << "> done building the mapping." << endl;
-    }
-
    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
            DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
@@ -239,34 +249,20 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const int seed,
                        const bool verbose) {

-    if (verbose) {
-        cout << "> building sample map using: " << endl << std::flush;
-        cout << "   number of epochs: " << num_epochs << endl << std::flush;
-        cout << "   maximum number of samples: " << max_num_samples << endl << std::flush;
-        cout << "   maximum sequence length: " << max_seq_length << endl << std::flush;
-        cout << "   short sequence probability: " << short_seq_prob << endl << std::flush;
-        cout << "   seed: " << seed << endl << std::flush;
-    }
-
    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {
-           cout << " > using uint64 for data mapping..." << endl << std::flush;
+           cout << "   using uint64 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
                                            short_seq_prob, seed, verbose);
    }
    else {
        if (verbose) {
-           cout << " > using uint32 for data mapping..." << endl << std::flush;
+           cout << "   using uint32 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
                                            short_seq_prob, seed, verbose);
    }
}
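The uint32/uint64 dispatch above simply picks the smallest index type that can address every sentence. A rough Python analogue of that decision (a sketch, not part of the commit):

    import numpy as np

    def mapping_dtype(num_sentences):
        # Use 32-bit indices while they are large enough, otherwise fall back to 64-bit.
        if num_sentences > np.iinfo(np.uint32).max:
            return np.uint64
        return np.uint32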
megatron/data/indexed_dataset.py
@@ -391,17 +391,17 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            offset = stream.tell()

            if not skip_warmup:
-                print_rank_0("> Warming up index mmap file...")
+                print_rank_0("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
-            print_rank_0("> Reading sizes...")
+            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32,
                                        count=self._len, offset=offset)
-            print_rank_0("> Reading pointers...")
+            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64,
                                           count=self._len,
                                           offset=offset + self._sizes.nbytes)
-            print_rank_0("> Reading document index...")
+            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64,
                                          count=self._doc_count,
                                          offset=offset + self._sizes.nbytes +
                                          self._pointers.nbytes)

        def __del__(self):
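For readers unfamiliar with the zero-copy read pattern in this index loader, here is a small self-contained sketch of the same idea on a toy file (the file name and layout are made up for illustration; this is not the MMapIndexedDataset format itself):

    import numpy as np

    # Write a toy index: three int32 sizes followed by three int64 pointers.
    sizes = np.array([4, 2, 3], dtype=np.int32)
    pointers = np.array([0, 8, 12], dtype=np.int64)
    with open('toy.idx', 'wb') as f:
        f.write(sizes.tobytes())
        f.write(pointers.tobytes())

    # Map the file and build zero-copy views, mirroring the reads above.
    buf = memoryview(np.memmap('toy.idx', mode='r', order='C'))
    sizes_rd = np.frombuffer(buf, dtype=np.int32, count=3, offset=0)
    pointers_rd = np.frombuffer(buf, dtype=np.int64, count=3,
                                offset=sizes_rd.nbytes)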
@@ -447,13 +447,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
-            print_rank_0("> Warming up data mmap file...")
+            print_rank_0("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
-        print_rank_0("> Creating numpy buffer of mmap...")
+        print_rank_0("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path),
                                          mode='r', order='C')
-        print_rank_0("> Creating memory view of numpy buffer...")
+        print_rank_0("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)
-        print_rank_0("> Done")

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
@@ -470,7 +469,6 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=size, offset=ptr)
            if self._index.dtype != np.int64:
                np_array = np_array.astype(np.int64)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
@@ -492,6 +490,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def doc_idx(self):
        return self._index.doc_idx

+    def get_doc_idx(self):
+        return self._index._doc_idx
+
+    def set_doc_idx(self, doc_idx_):
+        self._index._doc_idx = doc_idx_
+
    @property
    def supports_prefetch(self):
        return False
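These accessors exist so a caller can temporarily narrow the document index and then put it back. A hedged usage sketch (it mirrors build_dataset in albert_dataset.py above; start and end are hypothetical document-split boundaries):

    doc_idx_ptr = indexed_dataset.get_doc_idx()              # keep the full view
    indexed_dataset.set_doc_idx(doc_idx_ptr[start:end + 1])  # restrict to one split
    # ... build a dataset that only sees documents [start, end) ...
    indexed_dataset.set_doc_idx(doc_idx_ptr)                 # restore the full view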
megatron/data/split_dataset.py
@@ -13,43 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""dataset to split one large one into multiple smaller datasets"""

import torch
import numpy as np


def should_split(split):
    """
    given split proportions checks if should split
    Examples:
    >>> should_split([10,0,0])
    False
    >>> should_split([1,.1,.2])
    True
    """
    return max(split) / sum(split) != 1.


-def get_split(args):
-    """
-    Get dataset splits from comma separated string list
-    """
+def get_train_valid_test_split(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
    splits = []
-    if args.split.find(',') != -1:
-        splits = [float(s) for s in args.split.split(',')]
-    elif args.split.find('/') != -1:
-        splits = [float(s) for s in args.split.split('/')]
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
    else:
-        splits = [float(args.split)]
-    split_total = sum(splits)
-    if split_total < 1.:
-        splits.append(1 - split_total)
-    if args.valid_data is not None:
-        splits[1] = 0.
-    if args.test_data is not None:
-        splits[2] = 0.
-    final_sum = sum(splits)
-    return [s / final_sum for s in splits]
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    return splits_index


class SplitDataset(torch.utils.data.Dataset):
    """
pretrain_albert.py
@@ -13,21 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Pretrain BERT"""
+"""Pretrain ALBERT"""

import torch
import torch.nn.functional as F

from configure_data import configure_data
from megatron import mpu
from megatron.model import BertModel
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
-from megatron.data import AlbertDataset, split_dataset
+from megatron.data.albert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler


def model_provider(args):
    """Build the model."""
@@ -109,94 +109,98 @@ def forward_step(data_iterator, model, args, timers):

def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

-    (train_data, val_data, test_data) = (None, None, None)
+    (train_data, valid_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
-        if args.data_loader == None:
-            args.data_loader = 'binary'
-        if args.data_loader == 'binary':
-            if not args.max_num_samples:
-                args.max_num_samples = (args.train_iters +
-                                        2 * args.eval_iters) * args.batch_size
-            if not args.data_path:
-                print("Albert currently only supports a unified dataset specified with --data-path")
-                exit(1)
-            print_rank_0("Creating AlbertDataset...")
-            full_data = AlbertDataset(vocab_file=args.vocab,
-                                      data_prefix=args.data_path,
-                                      data_impl=args.data_impl,
-                                      skip_warmup=args.skip_mmap_warmup,
-                                      num_epochs=args.data_epochs,
-                                      max_num_samples=args.max_num_samples,
-                                      masked_lm_prob=args.mask_prob,
-                                      max_seq_length=args.seq_length,
-                                      short_seq_prob=args.short_seq_prob,
-                                      seed=args.seed)
-            print_rank_0("Finished creating AlbertDataset...")
-            split = split_dataset.get_split(args)
-            if split_dataset.should_split(split):
-                train_ds, val_ds, test_ds = split_dataset.split_ds(
-                    full_data, split, args.shuffle)
-            else:
-                train_ds = full_data
-            num_tokens = train_ds.num_tokens()
-            world_size = mpu.get_data_parallel_world_size()
-            rank = mpu.get_data_parallel_rank()
-            global_batch_size = args.batch_size * world_size
-            num_workers = args.num_workers
-
-            def make_data_loader_(dataset):
-                if not dataset:
-                    return None
-                # Use a simple sampler with distributed batch sampler.
-                sampler = torch.utils.data.SequentialSampler(dataset)
-                batch_sampler = DistributedBatchSampler(sampler=sampler,
-                                                        batch_size=global_batch_size,
-                                                        drop_last=True,
-                                                        rank=rank,
-                                                        world_size=world_size)
-                # Torch dataloader.
-                return torch.utils.data.DataLoader(dataset,
-                                                   batch_sampler=batch_sampler,
-                                                   num_workers=num_workers,
-                                                   pin_memory=True)
-
-            train_data = make_data_loader_(train_ds)
-            valid_data = make_data_loader_(val_ds)
-            test_data = make_data_loader_(test_ds)
-
-            do_train = train_data is not None and args.train_iters > 0
-            do_valid = valid_data is not None and args.eval_iters > 0
-            do_test = test_data is not None and args.eval_iters > 0
-            # Need to broadcast num_tokens and num_type_tokens.
-            token_counts = torch.cuda.LongTensor([num_tokens,
-                                                  2,  # hard coded num_type_tokens for now
-                                                  int(do_train),
-                                                  int(do_valid),
-                                                  int(do_test)])
-        else:
-            print("Unsupported data loader for BERT.")
+        print_rank_0('> building train, validation, and test datasets '
+                     'for ALBERT ...')
+
+        if args.data_loader is None:
+            args.data_loader = 'binary'
+        if args.data_loader != 'binary':
+            print('Unsupported {} data loader for ALBERT.'.format(
+                args.data_loader))
+            exit(1)
+        if not args.data_path:
+            print('ALBERT only supports a unified dataset specified '
+                  'with --data-path')
+            exit(1)
+
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [args.train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            vocab_file=args.vocab,
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            max_seq_length=args.seq_length,
+            masked_lm_prob=args.mask_prob,
+            short_seq_prob=args.short_seq_prob,
+            seed=args.seed,
+            skip_warmup=args.skip_mmap_warmup)
+        print_rank_0("> finished creating ALBERT datasets ...")
+
+        def make_data_loader_(dataset):
+            if not dataset:
+                return None
+            # Use a simple sampler with distributed batch sampler.
+            sampler = torch.utils.data.SequentialSampler(dataset)
+            batch_sampler = DistributedBatchSampler(sampler=sampler,
+                                                    batch_size=global_batch_size,
+                                                    drop_last=True,
+                                                    rank=data_parallel_rank,
+                                                    world_size=data_parallel_size)
+            # Torch dataloader.
+            return torch.utils.data.DataLoader(dataset,
+                                               batch_sampler=batch_sampler,
+                                               num_workers=args.num_workers,
+                                               pin_memory=True)
+
+        train_data = make_data_loader_(train_ds)
+        valid_data = make_data_loader_(valid_ds)
+        test_data = make_data_loader_(test_ds)
+
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
+        token_counts = torch.cuda.LongTensor([num_tokens,
+                                              2,  # hard coded num_type_tokens
+                                              int(do_train),
+                                              int(do_valid),
+                                              int(do_test)])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    num_type_tokens = token_counts[1].item()
+    args.vocab_size = token_counts[0].item()
+    args.tokentype_size = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()
-    args.vocab_size = num_tokens
-    args.tokentype_size = num_type_tokens

-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data


if __name__ == "__main__":
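One way to sanity-check the new sample-target arithmetic (a hedged worked example; the numbers are illustrative, not values from the commit):

    # train_iters=1000000, eval_interval=1000, eval_iters=10,
    # batch_size=4, 8-way data parallelism
    global_batch_size = 4 * 8                              # 32
    train_samples = 1000000 * global_batch_size            # 32,000,000
    eval_iters_total = (1000000 // 1000 + 1) * 10          # 10,010 validation iterations
    valid_samples = eval_iters_total * global_batch_size   # 320,320
    test_samples = 10 * global_batch_size                  # 320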