OpenDAS / Megatron-LM / Commits

Commit f66c58a9, authored Apr 07, 2020 by Mohammad
added build sample index to c++
Parent: fe207787
Showing 2 changed files with 157 additions and 3 deletions (+157 -3).
  megatron/data/helpers.cpp          +88  -0
  megatron/data/new_gpt2_dataset.py  +69  -3
megatron/data/helpers.cpp (view file @ f66c58a9)

@@ -33,6 +33,93 @@ using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;

py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch) {
    /* Sample index mapping is a 2D array with sizes [number-of-samples + 1, 2]
       where [..., 0] contains the index into `doc_idx` and [..., 1] is the
       starting offset in that document. */
    // Consistency checks.
    assert(seq_length > 1);
    assert(num_epochs > 0);
    assert(tokens_per_epoch > 1);

    // Remove bound checks.
    auto sizes = sizes_.unchecked<1>();
    auto doc_idx = doc_idx_.unchecked<1>();
    // Mapping and its length (1D).
    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
    int32_t* sample_idx = new int32_t[2 * (num_samples + 1)];
    cout << " using:" << endl << std::flush;
    cout << " number of documents: " << doc_idx_.shape(0) / num_epochs
         << endl << std::flush;
    cout << " number of epochs: " << num_epochs << endl << std::flush;
    cout << " sequence length: " << seq_length << endl << std::flush;
    cout << " total number of samples: " << num_samples << endl << std::flush;
    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Beginning offset for each document.
    int32_t doc_offset = 0;
    // Start with first document and no offset.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;
    while (sample_index <= num_samples) {
        // Start with a fresh sequence.
        int32_t remaining_seq_length = seq_length + 1;
        while (remaining_seq_length != 0) {
            // Get the document length.
            auto doc_id = doc_idx[doc_idx_index];
            auto doc_length = sizes[doc_id] - doc_offset;
            // And add it to the current sequence.
            remaining_seq_length -= doc_length;
            // If we have more than a full sequence, adjust offset and set
            // remaining length to zero so we return from the while loop.
            // Note that -1 here is for the same reason we have -1 in
            // `_num_epochs` calculations.
            if (remaining_seq_length <= 0) {
                doc_offset += (remaining_seq_length + doc_length - 1);
                remaining_seq_length = 0;
            } else {
                // Otherwise, start from the beginning of the next document.
                ++doc_idx_index;
                doc_offset = 0;
            }
        }
        // Record the sequence.
        sample_idx[2 * sample_index] = doc_idx_index;
        sample_idx[2 * sample_index + 1] = doc_offset;
        ++sample_index;
    }
    // Method to deallocate memory.
    py::capsule free_when_done(sample_idx, [](void *mem_) {
        int32_t *mem = reinterpret_cast<int32_t*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(int32_t);
    return py::array(std::vector<int64_t>{num_samples + 1, 2},  // shape
                     {2 * byte_size, byte_size},  // C-style contiguous strides
                     sample_idx,                  // the data pointer
                     free_when_done);             // numpy array references
}

inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
...

@@ -307,4 +394,5 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,

PYBIND11_MODULE(helpers, m) {
    m.def("build_mapping", &build_mapping);
    m.def("build_sample_idx", &build_sample_idx);
}
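
For orientation, here is a minimal sketch of how the compiled `helpers` extension is meant to be driven from Python. The dtypes follow the pybind11 signature above; the document sizes, epoch count, and sequence length are made-up illustration values, not from this commit.

    # Sketch only: assumes the `helpers` extension has been compiled and is importable.
    import numpy as np
    import helpers

    sizes = np.array([5, 3, 9], dtype=np.int32)              # tokens per document (illustrative)
    doc_idx = np.array([0, 1, 2, 1, 0, 2], dtype=np.int32)   # shuffled document order over 2 epochs
    seq_length = 4
    num_epochs = 2
    tokens_per_epoch = int(sizes.sum())

    sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                          num_epochs, tokens_per_epoch)
    # Shape is [num_samples + 1, 2]; sample i starts at document index
    # sample_idx[i, 0] with token offset sample_idx[i, 1] and ends where
    # sample i + 1 starts.
    print(sample_idx.shape, sample_idx.dtype)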

megatron/data/new_gpt2_dataset.py (view file @ f66c58a9)

@@ -22,11 +22,73 @@ import numpy as np
import torch
from torch.utils.data import Dataset

import helpers
#from bert_dataset import get_train_valid_test_split_


def print_rank_0(message):
    print(message)


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup):

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
            dataset = GPT2Dataset(name, data_prefix, documents, indexed_dataset,
                                  train_valid_test_num_samples[index],
                                  seq_length, seed)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
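

# Sketch only, not part of this commit: one way the entry point above might be
# called. The data prefix, split string, and sample counts are placeholder
# values chosen purely for illustration.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my-gpt2_text_document',   # placeholder preprocessed-data prefix
    data_impl='mmap',
    splits_string='949,50,1',
    train_valid_test_num_samples=[1000, 100, 10],
    seq_length=1024,
    seed=1234,
    skip_warmup=True)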


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')
    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))
    print_rank_0(' > indexed dataset stats:')
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


class GPT2Dataset(Dataset):

    def __init__(self, name, data_prefix,
...

@@ -121,8 +183,11 @@ def _build_index_mappings(name, data_prefix, documents, sizes,

                     '(seconds): {:4f}'.format(time.time() - start_time))
        # sample-idx.
        start_time = time.time()
        sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
                                       num_epochs, tokens_per_epoch)
        import helpers
        sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                              num_epochs, tokens_per_epoch)
        #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
        #                               num_epochs, tokens_per_epoch)
        np.save(sample_idx_filename, sample_idx, allow_pickle=True)
        print_rank_0(' > elapsed time to build and save sample-idx mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))
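
        # Sketch only, not part of this commit: with both implementations still
        # visible in this hunk, the C++ port can be sanity-checked against the
        # original pure-Python _build_sample_idx on the same inputs.
        py_sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
                                          num_epochs, tokens_per_epoch)
        assert np.array_equal(sample_idx, py_sample_idx), \
            'C++ and Python sample-idx mappings differ'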

@@ -186,6 +251,7 @@ def _build_doc_idx(documents, num_epochs, np_rng):
    doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
    doc_idx[:] = documents
    doc_idx = doc_idx.reshape(-1)
    doc_idx = doc_idx.astype(np.int32)
    np_rng.shuffle(doc_idx)
    return doc_idx
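
# Sketch only, not part of this commit: a tiny illustration of the np.mgrid
# construction used above. Each epoch repeats every document id once, and
# np_rng.shuffle then mixes documents across epoch boundaries.
demo_documents = np.array([3, 7, 9])                     # hypothetical document ids
demo_doc_idx = np.mgrid[0:2, 0:len(demo_documents)][1]   # 2 epochs x 3 documents
demo_doc_idx[:] = demo_documents
print(demo_doc_idx.reshape(-1))                          # [3 7 9 3 7 9] before shuffling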

@@ -323,7 +389,7 @@ if __name__ == '__main__':
    import random

    data_prefix = 'junk/'
    for seed in range(1234, 1240):
    for seed in range(1234, 1245):
        random.seed(seed)
        num_docs = random.randint(1, 999)
        min_doc_length = random.randint(1, 99)
...