wuxk1 / Megatron-LM / Commits / b7f1b050

Commit b7f1b050, authored Apr 14, 2020 by Neel Kant

    Lint whole repo

parent c99fa80c
Changes: 63
Showing 20 changed files with 389 additions and 276 deletions (+389, -276).
megatron/arguments.py                                                +0    -3
megatron/checkpointing.py                                            +3    -3
megatron/data/__init__.py                                            +0    -2
megatron/data/bert_dataset.py                                        +4    -6
megatron/data/gpt2_dataset.py                                        +8    -10
megatron/data/indexed_dataset.py                                     +13   -5
megatron/data/samplers.py                                            +4    -3
megatron/data/test/test_indexed_dataset.py                           +13   -10
megatron/deprecated_data_utils/__init__.py                           +18   -11
megatron/deprecated_data_utils/configure_data.py                     +7    -4
megatron/deprecated_data_utils/corpora.py                            +35   -32
megatron/deprecated_data_utils/datasets.py                           +75   -45
megatron/deprecated_data_utils/file_utils.py                         +1    -1
megatron/deprecated_data_utils/lazy_loader.py                        +23   -16
megatron/deprecated_data_utils/samplers.py                           +8    -4
megatron/deprecated_data_utils/scripts/presplit_sentences_json.py    +9    -9
megatron/deprecated_data_utils/scripts/split_gpt2_json.py            +21   -14
megatron/deprecated_data_utils/scripts/split_json.py                 +21   -14
megatron/deprecated_data_utils/tf_dl.py                              +22   -14
megatron/deprecated_data_utils/tokenization.py                       +104  -70
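The hunks below appear to be mostly mechanical PEP 8 cleanups rather than behavior changes. As a rough, hypothetical sketch (not code from the repository; load_state and legacy_loader are made-up names), the patterns that repeat throughout are: bare `except:` replaced by `except BaseException:` (flake8 E722), a space added after the `#` of commented-out code (E265), and long lines wrapped inside parentheses instead of backslash continuations:

    import sys


    def load_state(path):
        """Illustrative only: shows the lint patterns applied across this commit."""
        try:
            # commented-out code now has a space after '#' (E265):
            # state = legacy_loader(path)
            with open(path, 'rb') as f:
                state = f.read()
        except BaseException:  # was a bare 'except:' (E722)
            print('could not load', path)
            sys.exit()
        # long calls are wrapped inside parentheses rather than
        # continued with a trailing backslash
        return {'path': path,
                'size': len(state)}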
megatron/arguments.py (view file @ b7f1b050)

@@ -357,7 +357,6 @@ def _add_gpt2_args(parser):
    return parser


def add_data_args_(parser):
    """Train/valid/test data arguments."""

@@ -367,6 +366,4 @@ def add_data_args_(parser):
        choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
        help='Which data loader to use. Default varies by model.')
    return parser
megatron/checkpointing.py (view file @ b7f1b050)

@@ -67,7 +67,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
    directory = 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}'.format(
-                           mpu.get_model_parallel_rank() if mp_rank is None \
+                           mpu.get_model_parallel_rank() if mp_rank is None
                            else mp_rank),
                        'model_optim_rng.pt')

@@ -179,7 +179,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
            'megatron.fp16.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
-    except:
+    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()

@@ -190,7 +190,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
    try:
        iteration = state_dict['iteration']
    except KeyError:
        try:  # Backward compatible with older checkpoints
            iteration = state_dict['total_iters']
        except KeyError:
            print_rank_0('A metadata file exists but unable to load '
megatron/data/__init__.py (view file @ b7f1b050)

from . import indexed_dataset
megatron/data/bert_dataset.py (view file @ b7f1b050)

@@ -47,6 +47,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '

@@ -113,7 +114,6 @@ class BertDataset(Dataset):
        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,

@@ -133,11 +133,9 @@ class BertDataset(Dataset):
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_index, end_index, seq_length = self.samples_mapping[idx]

@@ -148,7 +146,7 @@ class BertDataset(Dataset):
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,

@@ -192,7 +190,7 @@ def get_train_valid_test_split_(splits_string, size):
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +

@@ -254,7 +252,7 @@ def get_samples_mapping_(indexed_dataset,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
            verbose)
megatron/data/gpt2_dataset.py (view file @ b7f1b050)

@@ -42,6 +42,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '

@@ -54,7 +55,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
            dataset = GPT2Dataset(name, data_prefix,
                                  documents, indexed_dataset,

@@ -102,21 +103,19 @@ class GPT2Dataset(torch.utils.data.Dataset):
            self.name, data_prefix, documents, self.indexed_dataset.sizes,
            num_samples, seq_length, seed)

    def __len__(self):
        # -1 is due to data structure used to retieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1

    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx + 1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx + 1][1]
        # If we are within the same document, just extract the chunk.
        if doc_index_f == doc_index_l:
            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],

@@ -127,18 +126,17 @@ class GPT2Dataset(torch.utils.data.Dataset):
            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                    offset=offset_f)]
            # Loop over all in between documents and add the entire document.
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of last document.
            sample_list.append(self.indexed_dataset.get(
                self.doc_idx[doc_index_l],
                length=offset_l + 1))
            sample = np.concatenate(sample_list)

        return {'text': np.array(sample, dtype=np.int64)}


def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.

@@ -185,7 +183,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
        assert sizes.dtype == np.int32
        sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                              num_epochs, tokens_per_epoch)
-       #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
+       # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
        #                               num_epochs, tokens_per_epoch)
        np.save(sample_idx_filename, sample_idx, allow_pickle=True)
        print_rank_0(' > elasped time to build and save sample-idx mapping '

@@ -194,7 +192,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
        start_time = time.time()
        # -1 is due to data structure used to retieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
        np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
        print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                     ' (seconds): {:4f}'.format(time.time() - start_time))
megatron/data/indexed_dataset.py (view file @ b7f1b050)

@@ -20,6 +20,7 @@ import numpy as np
import torch

from megatron import print_rank_0


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16

@@ -109,13 +110,15 @@ def index_file_path(prefix_path):
def data_file_path(prefix_path):
    return prefix_path + '.bin'


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
    _HDR_MAGIC = b'TNTIDX\x00\x00'

@@ -155,7 +158,7 @@ class IndexedDataset(torch.utils.data.Dataset):
        if self.data_file:
            self.data_file.close()

-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)

@@ -235,7 +238,7 @@ class IndexedCachedDataset(IndexedDataset):
            self.data_file.close()
            self.data_file = None

-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx

@@ -399,13 +402,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

@@ -464,7 +472,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def __len__(self):
        return len(self._index)

-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
megatron/data/samplers.py (view file @ b7f1b050)

@@ -81,6 +81,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
    sampler level. This allows wrapping of arbitrary data samplers
    (sequential, random, WeightedRandomSampler, etc.) with this batch
    sampler."""

    def __init__(self, sampler, batch_size, drop_last, rank=-1,
                 world_size=2, wrap_last=False):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size,

@@ -120,7 +121,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
    def data_iterator(self, _iter, wrap_around=False):
        """iterates through data and handles wrap around"""
        for i, idx in enumerate(_iter):
            if i < self.wrap_around % self.batch_size:
                continue
            if wrap_around:
                self.wrap_around += 1

@@ -129,6 +130,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
    def _batch(self, batch):
        """extracts samples only pertaining to this worker's batch"""
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]
megatron/data/test/test_indexed_dataset.py (view file @ b7f1b050)

@@ -2,6 +2,8 @@
# put some code used during development and manual testing of
# indexed_dataset.
+from megatron.data import indexed_dataset
+from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys

@@ -11,8 +13,6 @@ import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))

-from megatron.tokenizer import build_tokenizer
-from megatron.data import indexed_dataset


def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)

@@ -23,12 +23,12 @@ def test_indexed_dataset(args):
    if ds.supports_prefetch:
        # just prefetch the whole thing in test (so assume it is small)
        ds.prefetch(range(len(ds)))
    if args.count > len(ds.doc_idx) - 1:
        args.count = len(ds.doc_idx) - 1

    for i in range(args.count):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i + 1]
        ids = ds[start:end]
        print(f"Document {i}:")
        print("--------------")

@@ -39,6 +39,7 @@ def test_indexed_dataset(args):
        print(text)
        print("---")


def test_indexed_dataset_get(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)

@@ -46,19 +47,19 @@ def test_indexed_dataset_get(args):
    print(f"size: {size}")
    full = ds.get(0)
    print(full)
-    #print(tokenizer.detokenize(full.data.tolist()))
+    # print(tokenizer.detokenize(full.data.tolist()))
    print("---")
    end = ds.get(0, offset=size - 10)
    print(end)
-    #print(tokenizer.detokenize(end.data.tolist()))
+    # print(tokenizer.detokenize(end.data.tolist()))
    start = ds.get(0, length=10)
    print(start)
-    #print(tokenizer.detokenize(start.data.tolist()))
+    # print(tokenizer.detokenize(start.data.tolist()))
    part = ds.get(0, offset=2, length=8)
    print(part)
-    #print(tokenizer.detokenize(part.data.tolist()))
+    # print(tokenizer.detokenize(part.data.tolist()))

# def test_albert_dataset(args):
#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)

@@ -77,6 +78,7 @@ def test_indexed_dataset_get(args):
#         if i >= args.count-1:
#             exit()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')

@@ -118,5 +120,6 @@ def main():
    # test_albert_dataset(args)
    test_indexed_dataset_get(args)


if __name__ == "__main__":
    main()
megatron/deprecated_data_utils/__init__.py (view file @ b7f1b050)

@@ -28,21 +28,24 @@ TRAIN_DATA = 0
VAL_DATA = 1
TEST_DATA = 2


def should_split(split):
    """
    given split proportions checks if should split
    Examples:
    >>> should_split([10,0,0])
    False
    >>> should_split([1,.1,.2])
    True
    """
    return max(split) / sum(split) != 1.


def get_ext(path):
    """gets path extension"""
    return os.path.splitext(path)[1]


def get_dataset(path, **kwargs):
    """gets dataset object based on keyword args and file at `path`"""
    if supported_corpus(path):

@@ -53,17 +56,19 @@ def get_dataset(path, **kwargs):
    elif ext in ['.csv', '.tsv']:
        text = csv_dataset(path, **kwargs)
    else:
        raise NotImplementedError('data file type %s is not supported' % (ext))
    return text


def supported_corpus(corpus_name):
    """checks if corpus name is defined in `corpora.py`"""
    return corpus_name in corpora.NAMED_CORPORA


def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
                 delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
                 tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
                 model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
                 parallel_group=None, **kwargs):
    """function to create datasets+tokenizers for common options"""
    if isinstance(process_fn, str):

@@ -71,6 +76,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
    if non_binary_cols is not None:
        # multilabel dataset support (only for csvs)
        label_key = non_binary_cols

    def get_dataset_from_path(path_):
        if lazy:
            # get lazily loaded dataset

@@ -82,7 +88,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
        if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
            # create cached version of dataset for lazy loading if it doesn't exist
            text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
                               delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
            make_lazy(path_, text.X, data_type='data')
        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model

@@ -96,7 +102,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
        else:
            # get dataset
            text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
                               delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
        return text

    # get one or multiple datasets and concatenate
    if isinstance(path, str):

@@ -108,8 +114,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
        ds = ConcatDataset(datasets)
    # make tokenizer for dataset
    if tokenizer is None:
        tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
                                   pad_token, character_converage, **kwargs)

    ds_type = ''
    if 'ds_type' in kwargs:

@@ -121,7 +127,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
    if 'bert' in ds_type.lower():
        presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
        dstype = bert_sentencepair_dataset
        ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
    elif ds_type.lower() == 'gpt2':
        ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
    else:
megatron/deprecated_data_utils/configure_data.py (view file @ b7f1b050)

@@ -21,6 +21,7 @@ import torch
from megatron import data_utils
from megatron import mpu


class DataConfig:

    def __init__(self, defaults={}):

@@ -48,7 +49,8 @@ def make_data_loader(dataset, batch_size, args):
    shuffle = args.shuffle
    if shuffle:
        sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size * args.train_iters)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    world_size = torch.distributed.get_world_size(

@@ -204,6 +206,7 @@ def make_loaders(args):
    return (train, valid, test), tokenizer


def get_split(args):
    """
    Get dataset splits from comma separated string list

@@ -217,7 +220,7 @@ def get_split(args):
        splits = [float(args.split)]
    split_total = sum(splits)
    if split_total < 1.:
        splits.append(1 - split_total)
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]

@@ -226,10 +229,10 @@ def get_split(args):
    if args.test_data is not None:
        splits[2] = 0.
    final_sum = sum(splits)
    return [s / final_sum for s in splits]


def configure_data():
    """add cmdline flags for configuring datasets"""
    # These are options that are used by data_utils, but are either
    # deprecated or not meant to be exposed to the command line user.
megatron/deprecated_data_utils/corpora.py (view file @ b7f1b050)

@@ -16,43 +16,46 @@
from .datasets import json_dataset, csv_dataset
import os


class wikipedia(json_dataset):
    """
    dataset for wikipedia with arguments configured for convenience
    command line usage: `--train-data wikipedia`
    """
    PATH = 'data/wikipedia/wikidump_lines.json'
    assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"

    def __init__(self, **kwargs):
        assert os.path.exists(wikipedia.PATH), \
            wikipedia.assert_str
        if not kwargs:
            kwargs = {}
        kwargs['text_key'] = 'text'
        kwargs['loose_json'] = True
        super(wikipedia, self).__init__(wikipedia.PATH, **kwargs)


class webtext(json_dataset):
    """
    dataset for webtext with arguments configured for convenience
    command line usage: `--train-data webtext`
    """
    PATH = 'data/webtext/data.json'
    assert_str = "make sure to set PATH for webtext data_utils/corpora.py"

    def __init__(self, **kwargs):
        assert os.path.exists(webtext.PATH), \
            webtext.assert_str
        if not kwargs:
            kwargs = {}
        kwargs['text_key'] = 'text'
        kwargs['loose_json'] = True
        super(webtext, self).__init__(webtext.PATH, **kwargs)


NAMED_CORPORA = {
    'wikipedia': wikipedia,
    'webtext': webtext,
}
megatron/deprecated_data_utils/datasets.py (view file @ b7f1b050)

@@ -34,6 +34,7 @@ from nltk import tokenize
from .lazy_loader import lazy_array_loader, exists_lazy, make_lazy
from .tokenization import Tokenization


class ConcatDataset(data.Dataset):
    """
    Dataset to concatenate multiple datasets.

@@ -57,7 +58,8 @@ class ConcatDataset(data.Dataset):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.is_lazy = sum([isinstance(ds, lazy_array_loader) for ds in self.datasets]) == len(self.datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)
        self._X = None
        self._Y = None

@@ -90,7 +92,8 @@ class ConcatDataset(data.Dataset):
                self._lens.extend(data.lens)
            else:
                for data in self.datasets:
                    self._lens.extend([len(d['text']) if isinstance(d, dict) else len(d) for d in data])
        return self._lens

    @property

@@ -116,6 +119,7 @@ class ConcatDataset(data.Dataset):
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes


class SplitDataset(data.Dataset):
    """
    Dataset wrapper to access a subset of another dataset.

@@ -126,6 +130,7 @@ class SplitDataset(data.Dataset):
        ds (Dataset or array-like): List of datasets to be subindexed
        split_inds (1D array-like): List of indices part of subset
    """

    def __init__(self, ds, split_inds, **kwargs):
        self.split_inds = list(split_inds)
        self.wrapped_data = ds

@@ -163,7 +168,8 @@ class SplitDataset(data.Dataset):
        for idx in self.split_inds:
            yield self.wrapped_data[idx]

-def split_ds(ds, split=[.8,.2,.0], shuffle=True):
+def split_ds(ds, split=[.8, .2, .0], shuffle=True):
    """
    Split a dataset into subsets given proportions of how
    much to allocate per split. If a split is 0% returns None for that split.

@@ -184,18 +190,19 @@ def split_ds(ds, split=[.8,.2,.0], shuffle=True):
        np.random.shuffle(inds)
    start_idx = 0
    residual_idx = 0
    rtn_ds = [None] * len(split)
    for i, f in enumerate(split):
        if f != 0:
            proportion = ds_len * split[i]
            residual_idx += proportion % 1
            split_ = int(int(proportion) + residual_idx)
            split_inds = inds[start_idx:start_idx + max(split_, 1)]
            rtn_ds[i] = SplitDataset(ds, split_inds)
            start_idx += split_
            residual_idx %= 1
    return rtn_ds


class csv_dataset(data.Dataset):
    """
    Class for loading datasets from csv files.

@@ -214,9 +221,10 @@ class csv_dataset(data.Dataset):
        X (list): all strings from the csv file
        Y (np.ndarray): labels to train with
    """

    def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',',
                 binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
                 **kwargs):
        self.is_lazy = False
        self.preprocess_fn = preprocess_fn
        self.SetTokenizer(tokenizer)

@@ -229,7 +237,6 @@ class csv_dataset(data.Dataset):
        if '.tsv' in self.path:
            self.delim = '\t'

        self.X = []
        self.Y = []
        try:

@@ -239,7 +246,7 @@ class csv_dataset(data.Dataset):
            else:
                cols += [label_key]
            data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1')
-        except:
+        except BaseException:
            data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1')

        data = data.dropna(axis=0)

@@ -248,7 +255,7 @@ class csv_dataset(data.Dataset):
        try:
            self.Y = data[label_key].values
        except Exception as e:
            self.Y = np.ones(len(self.X)) * -1

        if binarize_sent:
            self.Y = binarize_labels(self.Y, hard=binarize_sent)

@@ -295,23 +302,25 @@ class csv_dataset(data.Dataset):
        write the metrics, text, and labels to a csv file
        """
        if path is None:
            path = self.path + '.results'
        print('generating csv at ' + path)
        with open(path, 'w') as csvfile:
            c = csv.writer(csvfile, delimiter=self.delim)
            if writer_gen is not None:
-                #if first item of generator is a header of what the metrics mean then write header to csv file
+                # if first item of generator is a header of what the metrics mean then
+                # write header to csv file
                if not skip_header:
                    header = (self.label_key,) + tuple(next(writer_gen)) + (self.text_key,)
                    c.writerow(header)
                for i, row in enumerate(writer_gen):
                    row = (self.Y[i],) + tuple(row) + (self.X[i],)
                    c.writerow(row)
            else:
                c.writerow([self.label_key, self.text_key])
                for row in zip(self.Y, self.X):
                    c.writerow(row)


class json_dataset(data.Dataset):
    """
    Class for loading datasets from a json dump.

@@ -327,8 +336,9 @@ class json_dataset(data.Dataset):
        all_strs (list): list of all strings from the dataset
        all_labels (list): list of all labels from the dataset (if they have it)
    """

    def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
                 text_key='sentence', label_key='label', loose_json=False, **kwargs):
        self.is_lazy = False
        self.preprocess_fn = preprocess_fn
        self.path = path

@@ -389,24 +399,25 @@ class json_dataset(data.Dataset):
        write the metrics, text, and labels to a json file
        """
        if path is None:
            path = self.path + '.results'

        jsons = []

        if writer_gen is not None:
-            #if first item of generator is a header of what the metrics mean then write header to csv file
+            # if first item of generator is a header of what the metrics mean then
+            # write header to csv file
            def gen_helper():
                keys = {}
                keys[0] = self.label_key
                if not skip_header:
                    for idx, k in enumerate(tuple(next(writer_gen))):
                        keys[idx + 1] = k
                for i, row in enumerate(writer_gen):
                    if i == 0 and skip_header:
                        for idx, _ in enumerate(row):
                            keys[idx + 1] = 'metric_%d' % (idx,)
                    j = {}
                    for idx, v in enumerate((self.Y[i],) + tuple(row)):
                        k = keys[idx]
                        j[k] = v
                    yield j

@@ -453,6 +464,7 @@ class json_dataset(data.Dataset):
                j[self.label_key] = -1
                yield j


class GPT2Dataset(data.Dataset):

    def __init__(self, ds,

@@ -503,7 +515,7 @@ class GPT2Dataset(data.Dataset):
    def __getitem__(self, idx):
        # init rng
        rng = random.Random(idx)
        rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
        # get possibly weighted random index from dataset
        data_idx = self.get_weighted_samples(rng)

@@ -538,10 +550,10 @@ class GPT2Dataset(data.Dataset):
            else:
                data_idx = (data_idx + 1) % self.ds_len
                tokens += self.getidx(data_idx)
        tokens = tokens[:(self.max_seq_len + 1)]
        tokens = self.pad_seq(tokens)
-        return {'text': np.array(tokens),}
+        return {'text': np.array(tokens), }

    def getidx(self, data_idx):
        data = self.ds[data_idx]

@@ -556,7 +568,7 @@ class GPT2Dataset(data.Dataset):
    def pad_seq(self, seq):
        total_tokens = self.max_seq_len + 1
        num_pad_tokens = max(0, total_tokens - len(seq))
        seq += [self.tokenizer.get_command('pad').Id] * (num_pad_tokens)
        return seq

    def contains_sentence_end(self, tok):

@@ -569,6 +581,7 @@ class GPT2Dataset(data.Dataset):
            return True
        return False


class bert_sentencepair_dataset(data.Dataset):
    """
    Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair.

@@ -581,7 +594,9 @@ class bert_sentencepair_dataset(data.Dataset):
        dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
    """

    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
        self.ds = ds
        self.ds_len = len(self.ds)
        self.tokenizer = self.ds.GetTokenizer()

@@ -590,12 +605,12 @@ class bert_sentencepair_dataset(data.Dataset):
        self.max_seq_len = max_seq_len
        self.mask_lm_prob = mask_lm_prob
        if max_preds_per_seq is None:
            max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10
        self.max_preds_per_seq = max_preds_per_seq
        self.short_seq_prob = short_seq_prob
        self.dataset_size = dataset_size
        if self.dataset_size is None:
            self.dataset_size = self.ds_len * (self.ds_len - 1)
        self.presplit_sentences = presplit_sentences
        if not self.presplit_sentences:
            nltk.download('punkt', download_dir="./nltk")

@@ -607,7 +622,8 @@ class bert_sentencepair_dataset(data.Dataset):
            if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
                lens = np.array(self.ds.lens)
            else:
                lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
            self.total_len = np.sum(lens)
            self.weighting = list(accumulate(lens))
        else:

@@ -626,7 +642,7 @@ class bert_sentencepair_dataset(data.Dataset):
    def __getitem__(self, idx):
        # get rng state corresponding to index (allows deterministic random pair)
        rng = random.Random(idx)
        np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
        # get seq length
        target_seq_length = self.max_seq_len
        short_seq = False

@@ -639,15 +655,25 @@ class bert_sentencepair_dataset(data.Dataset):
        lena = 0
        lenb = 0
        while (is_random_next is None) or (lena < 1) or (lenb < 1):
            tokensa, tokensb, is_random_next = self.create_random_sentencepair(target_seq_length, rng, np_rng)
            lena = len(tokensa[0])
            lenb = len(tokensb[0])

        # truncate sentence pair to max_seq_len
        tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng)
        # join sentence pair, mask, and pad
-        tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
+        tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(
+            tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
        sample = {'text': np.array(tokens[0]), 'types': np.array(tokens[1]), 'is_random': int(is_random_next), 'mask': np.array(mask), 'mask_labels': np.array(mask_labels), 'pad_mask': np.array(pad_mask)}
        return sample

@@ -665,7 +691,7 @@ class bert_sentencepair_dataset(data.Dataset):
        """tokenize sentence and get token types"""
        tokens = self.tokenizer.EncodeAsIds(sent).tokenization
        str_type = 'str' + str(sentence_num)
        token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
        return tokens, token_types

    def get_doc(self, idx):

@@ -694,21 +720,22 @@ class bert_sentencepair_dataset(data.Dataset):
            # doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting)
            doc_a_idx = self.get_weighted_samples(np_rng)
        else:
            doc_a_idx = rng.randint(0, self.ds_len - 1)
        doc_a = self.sentence_split(self.get_doc(doc_a_idx))
        if not doc_a:
            doc_a = None

        random_start_a = rng.randint(0, len(doc_a) - 1)
        while random_start_a < len(doc_a):
            sentence = doc_a[random_start_a]
            sentence, sentence_types = self.sentence_tokenize(sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
            curr_strs.append(sentence)
            curr_str_types.append(sentence_types)
            curr_len += len(sentence)
            if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length:
                break
            random_start_a = (random_start_a + 1)
        if curr_strs:
            num_a = 1

@@ -738,16 +765,17 @@ class bert_sentencepair_dataset(data.Dataset):
        if not doc_b:
            doc_b = None
        random_start_b
=
rng
.
randint
(
0
,
len
(
doc_b
)
-
1
)
random_start_b
=
rng
.
randint
(
0
,
len
(
doc_b
)
-
1
)
while
random_start_b
<
len
(
doc_b
):
while
random_start_b
<
len
(
doc_b
):
sentence_b
=
doc_b
[
random_start_b
]
sentence_b
=
doc_b
[
random_start_b
]
new_b_tokens
,
new_b_types
=
self
.
sentence_tokenize
(
sentence_b
,
1
,
random_start_b
==
0
,
random_start_b
==
len
(
doc_b
))
new_b_tokens
,
new_b_types
=
self
.
sentence_tokenize
(
sentence_b
,
1
,
random_start_b
==
0
,
random_start_b
==
len
(
doc_b
))
b_len
+=
len
(
new_b_tokens
)
b_len
+=
len
(
new_b_tokens
)
tokens_b
.
extend
(
new_b_tokens
)
tokens_b
.
extend
(
new_b_tokens
)
token_types_b
.
extend
(
new_b_types
)
token_types_b
.
extend
(
new_b_types
)
if
len
(
tokens_b
)
>=
target_b_length
:
if
len
(
tokens_b
)
>=
target_b_length
:
break
break
random_start_b
=
(
random_start_b
+
1
)
random_start_b
=
(
random_start_b
+
1
)
else
:
else
:
is_random_next
=
False
is_random_next
=
False
for
j
in
range
(
num_a
,
len
(
curr_strs
)):
for
j
in
range
(
num_a
,
len
(
curr_strs
)):
...
@@ -812,13 +840,15 @@ class bert_sentencepair_dataset(data.Dataset):
...
@@ -812,13 +840,15 @@ class bert_sentencepair_dataset(data.Dataset):
def
pad_seq
(
self
,
seq
):
def
pad_seq
(
self
,
seq
):
"""helper function to pad sequence pair"""
"""helper function to pad sequence pair"""
num_pad
=
max
(
0
,
self
.
max_seq_len
-
len
(
seq
))
num_pad
=
max
(
0
,
self
.
max_seq_len
-
len
(
seq
))
pad_mask
=
[
0
]
*
len
(
seq
)
+
[
1
]
*
num_pad
pad_mask
=
[
0
]
*
len
(
seq
)
+
[
1
]
*
num_pad
seq
+=
[
self
.
tokenizer
.
get_command
(
'pad'
).
Id
]
*
num_pad
seq
+=
[
self
.
tokenizer
.
get_command
(
'pad'
).
Id
]
*
num_pad
return
seq
,
pad_mask
return
seq
,
pad_mask
def
concat_tokens
(
self
,
tokens_a
,
token_types_a
,
tokens_b
,
token_types_b
):
def
concat_tokens
(
self
,
tokens_a
,
token_types_a
,
tokens_b
,
token_types_b
):
tokens
=
[
self
.
tokenizer
.
get_command
(
'ENC'
).
Id
]
+
tokens_a
+
[
self
.
tokenizer
.
get_command
(
'sep'
).
Id
]
+
tokens_b
+
[
self
.
tokenizer
.
get_command
(
'sep'
).
Id
]
tokens
=
[
self
.
tokenizer
.
get_command
(
'ENC'
).
Id
]
+
tokens_a
+
[
self
.
tokenizer
.
get_command
(
token_types
=
[
token_types_a
[
0
]]
+
token_types_a
+
[
token_types_a
[
0
]]
+
token_types_b
+
[
token_types_b
[
0
]]
'sep'
).
Id
]
+
tokens_b
+
[
self
.
tokenizer
.
get_command
(
'sep'
).
Id
]
token_types
=
[
token_types_a
[
0
]]
+
token_types_a
+
\
[
token_types_a
[
0
]]
+
token_types_b
+
[
token_types_b
[
0
]]
return
tokens
,
token_types
return
tokens
,
token_types
def
create_masked_lm_predictions
(
self
,
a
,
b
,
mask_lm_prob
,
max_preds_per_seq
,
vocab_words
,
rng
):
def
create_masked_lm_predictions
(
self
,
a
,
b
,
mask_lm_prob
,
max_preds_per_seq
,
vocab_words
,
rng
):
...
@@ -833,7 +863,7 @@ class bert_sentencepair_dataset(data.Dataset):
...
@@ -833,7 +863,7 @@ class bert_sentencepair_dataset(data.Dataset):
len_a
=
len
(
tokens_a
)
len_a
=
len
(
tokens_a
)
len_b
=
len
(
tokens_b
)
len_b
=
len
(
tokens_b
)
cand_indices
=
[
idx
+
1
for
idx
in
range
(
len_a
)]
+
[
idx
+
2
+
len_a
for
idx
in
range
(
len_b
)]
cand_indices
=
[
idx
+
1
for
idx
in
range
(
len_a
)]
+
[
idx
+
2
+
len_a
for
idx
in
range
(
len_b
)]
rng
.
shuffle
(
cand_indices
)
rng
.
shuffle
(
cand_indices
)
...
...
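The `__getitem__` hunk above relies on per-index seeding to make the otherwise random sentence pairing and masking reproducible across epochs and workers. A minimal sketch of that pattern, using only the two RNG lines shown in the diff (the `rngs_for_index` helper name is illustrative, not part of the code):

import random
import numpy as np

def rngs_for_index(idx):
    # Seed a Python RNG from the sample index, then derive a NumPy RNG from it,
    # so every call with the same idx reproduces the same random decisions.
    rng = random.Random(idx)
    np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
    return rng, np_rng

r1, n1 = rngs_for_index(42)
r2, n2 = rngs_for_index(42)
assert r1.random() == r2.random()
assert n1.randint(0, 10) == n2.randint(0, 10)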
megatron/deprecated_data_utils/file_utils.py
View file @ b7f1b050

@@ -169,7 +169,7 @@ def http_get(url, temp_file):
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
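The `http_get` hunk streams the response in 1 KiB chunks and sizes the tqdm bar from the Content-Length header when one is present. A self-contained sketch of the same pattern, assuming the `requests`/`tqdm` pair this helper is built on (the `requests.get(url, stream=True)` line is not part of the hunk and is an assumption here):

import requests
from tqdm import tqdm

def http_get_sketch(url, temp_file):
    # Stream the response and report progress; total stays None when the
    # server sends no Content-Length header.
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()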
megatron/deprecated_data_utils/lazy_loader.py
View file @ b7f1b050

@@ -22,11 +22,13 @@ from itertools import accumulate
import torch
from torch.multiprocessing import Lock


def get_lazy_path(path):
    """
    Gets directory path where lazy files are stored.
    """
    return os.path.splitext(path)[0] + '.lazy'


def exists_lazy(path, data_type='data'):
    """
@@ -37,10 +39,11 @@ def exists_lazy(path, data_type='data'):
    contents = os.listdir(get_lazy_path(path))
    if data_type not in contents:
        return False
    if data_type + '.len.pkl' not in contents:
        return False
    return True


def make_lazy(path, strs, data_type='data'):
    """
    Make lazy version of `data_type` field of the file. Byte offsets
@@ -50,7 +53,7 @@ def make_lazy(path, strs, data_type='data'):
    if not os.path.exists(lazypath):
        os.makedirs(lazypath)
    datapath = os.path.join(lazypath, data_type)
    lenpath = os.path.join(lazypath, data_type + '.len.pkl')
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        with open(datapath, 'wb') as f:
            str_lens = []
@@ -67,28 +70,32 @@ def make_lazy(path, strs, data_type='data'):
        while not os.path.exists(lenpath):
            time.sleep(1)


def split_strings(strings, start, chr_lens):
    """
    Split strings based on string lengths and given start.
    """
    return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]


class ProcessorTokenizer:
    """
    callable class that runs a preprocessing, as well as tokenization step,
    on input text.
    """

    def __init__(self, tokenizer, process_fn=None):
        self.tokenizer = tokenizer
        self.process_fn = process_fn

    def __call__(self, string):
        if self.tokenizer is not None:
            string = self.tokenizer(string, process_fn=self.process_fn)
        elif self.process_fn is not None:
            string = self.process_fn(string)
        return string


class lazy_array_loader(object):
    """
    Arguments:
@@ -107,17 +114,18 @@ class lazy_array_loader(object):
        data_type2
        data_type2.len.pkl
    """

    def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
        lazypath = get_lazy_path(path)
        datapath = os.path.join(lazypath, data_type)
        # get file where array entries are concatenated into one big string
        self._file = open(datapath, 'rb', buffering=0)
        self.file = self._file
        # memory map file if necessary
        self.mem_map = mem_map
        if self.mem_map:
            self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
        lenpath = os.path.join(lazypath, data_type + '.len.pkl')
        self.lens = pkl.load(open(lenpath, 'rb'))
        self.ends = list(accumulate(self.lens))
        self.dumb_ends = list(self.ends)
@@ -149,7 +157,7 @@ class lazy_array_loader(object):
        if index == 0:
            start = 0
        else:
            start = self.ends[index - 1]
        end = self.ends[index]
        rtn = self.file_read(start, end)
        if self.map_fn is not None:
@@ -160,7 +168,7 @@ class lazy_array_loader(object):
            if index.start == 0 or index.start is None:
                start = 0
            else:
                start = self.ends[index.start - 1]
            stop = chr_lens[-1]
            strings = self.file_read(start, stop)
            rtn = split_strings(strings, start, chr_lens)
@@ -181,15 +189,14 @@ class lazy_array_loader(object):
        # read to end of file if no end point provided
        if end is None:
            rtn = self.file.read()
        # else read amount needed to reach end point
        else:
            rtn = self.file.read(end - start)
        self.read_lock.release()
        # TODO: @raulp figure out mem map byte string bug
        # if mem map'd need to decode byte string to string
        rtn = rtn.decode('utf-8', 'ignore')
        # rtn = str(rtn)
        if self.mem_map:
            rtn = rtn.decode('unicode_escape')
        return rtn
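The lazy format above stores every record concatenated into one data file plus a pickled list of per-record lengths; `accumulate` turns the lengths into absolute end offsets and `split_strings` slices records back out. A small round-trip sketch with toy records (the `records`/`blob` names are illustrative):

from itertools import accumulate

def split_strings(strings, start, chr_lens):
    # Recover individual records from a concatenated buffer given absolute
    # end offsets (chr_lens) and the absolute offset where `strings` begins.
    return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]

records = ["hello", "lazy", "loader"]
blob = "".join(records)                 # what make_lazy writes to the data file
lens = [len(r) for r in records]        # what make_lazy pickles to .len.pkl
ends = list(accumulate(lens))           # what lazy_array_loader precomputes
assert split_strings(blob, 0, ends) == records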
megatron/deprecated_data_utils/samplers.py
View file @ b7f1b050

@@ -21,6 +21,7 @@ import torch
from torch.utils import data
import numpy as np


class RandomSampler(data.sampler.Sampler):
    r"""
    Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
@@ -63,7 +64,8 @@ class RandomSampler(data.sampler.Sampler):
        if self.epoch >= 0:
            g.manual_seed(self.epoch)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,),
                                      dtype=torch.int64, generator=g).tolist())
        return iter(torch.randperm(n, generator=g).tolist())

    def __len__(self):
@@ -72,12 +74,14 @@ class RandomSampler(data.sampler.Sampler):
    def set_epoch(self, epoch):
        self.epoch = epoch


class DistributedBatchSampler(data.sampler.BatchSampler):
    """
    similar to normal implementation of distributed sampler, except implementation is at the
    batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
    data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
    """

    def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
        if rank == -1:
@@ -125,7 +129,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
    def data_iterator(self, _iter, wrap_around=False):
        """iterates through data and handles wrap around"""
        for i, idx in enumerate(_iter):
            if i < self.wrap_around % self.batch_size:
                continue
            if wrap_around:
                self.wrap_around += 1
@@ -134,6 +138,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
    def _batch(self, batch):
        """extracts samples only pertaining to this worker's batch"""
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]
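`_batch` gives each rank a contiguous slice of the global batch. A sketch of the arithmetic with toy values; unlike the sampler, which uses its configured `batch_size`, this standalone helper derives it from `len(batch)` for brevity:

def per_rank_slice(batch, rank, world_size):
    # Mirrors DistributedBatchSampler._batch: rank r keeps the r-th contiguous
    # slice of the global batch, so batch_size should divide evenly by
    # world_size for equal shares.
    batch_size = len(batch)
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    return batch[start:end]

global_batch = list(range(8))
assert per_rank_slice(global_batch, 0, 4) == [0, 1]
assert per_rank_slice(global_batch, 3, 4) == [6, 7]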
megatron/deprecated_data_utils/scripts/presplit_sentences_json.py
View file @ b7f1b050

@@ -16,12 +16,12 @@ output_file = sys.argv[2]
line_seperator = "\n"

with open(input_file, 'r') as ifile:
    with open(output_file, "w") as ofile:
        for doc in ifile.readlines():
            parsed = json.loads(doc)
            sent_list = []
            for line in parsed['text'].split('\n'):
                if line != '\n':
                    sent_list.extend(nltk.tokenize.sent_tokenize(line))
            parsed['text'] = line_seperator.join(sent_list)
            ofile.write(json.dumps(parsed) + '\n')
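The script rewrites each loose-JSON document so that its `text` field holds one NLTK-detected sentence per line. The same transformation on a single in-memory document, with an invented two-sentence example:

import json
import nltk

nltk.download('punkt', quiet=True)

parsed = {"text": "First sentence. Second sentence.\nAnother paragraph here."}
sent_list = []
for line in parsed['text'].split('\n'):
    if line != '\n':
        sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = "\n".join(sent_list)
print(json.dumps(parsed))
# {"text": "First sentence.\nSecond sentence.\nAnother paragraph here."}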
megatron/deprecated_data_utils/scripts/split_gpt2_json.py
View file @ b7f1b050

@@ -18,7 +18,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.

Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
@@ -35,6 +35,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
                    help='percentage of available data to use for val/test dataset')
args = parser.parse_args()


def get_lines(filepath):
    lines = []
    with open(filepath, 'r') as f:
@@ -43,6 +44,7 @@ def get_lines(filepath):
            lines.append(l)
    return lines


def get_splits(lines, line_counts):
    all_lines = []
    line_idx = []
@@ -50,14 +52,14 @@ def get_splits(lines, line_counts):
    for i, l in enumerate(lines):
        all_lines.extend(l)
        line_idx.extend(list(range(len(l))))
        file_mappings.extend([i] * len(l))
    indices = list(range(len(all_lines)))
    random.shuffle(indices)
    all_lines = [all_lines[idx] for idx in indices]
    line_idx = [line_idx[idx] for idx in indices]
    file_mappings = [file_mappings[idx] for idx in indices]
    splits = []
    mappings = []
    start = 0
@@ -68,10 +70,11 @@ def get_splits(lines, line_counts):
        start = end
    return splits, mappings


def format_mappings(line_idx, file_mappings):
    lines = []
    for m, l in zip(file_mappings, line_idx):
        lines.append(str(m).strip() + '\t' + str(l).strip())
    return lines
@@ -85,25 +88,30 @@ def get_filepaths(filepaths, output_dir):
    paths.append(os.path.join(output_dir, test_path))
    return paths


def write_files(lines, mappings, filepaths):
    for l, m, path in zip(lines, mappings, filepaths):
        write_file(l, path)
        write_mapping_file(m, path)


def write_file(lines, path):
    print('Writing:', path)
    with open(path, 'w') as f:
        for l in lines:
            f.write(l + '\n')


def write_mapping_file(m, path):
    path = path + '.map'
    m = [get_mapping_header()] + m
    write_file(m, path)


def get_mapping_header():
    return 'file\tline #'


if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
@@ -113,16 +121,16 @@ for filepath in args.input_files:
    _lines = get_lines(filepath)
    lines.append(_lines)

# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent) == 2:
    test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
@@ -131,4 +139,3 @@ splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
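The split sizes are derived by rounding the val/test fractions up and giving the remainder to train. A worked example with illustrative numbers (1000 total lines and the script's default `--test_percent 0.05 0`):

import math

total_lines = 1000                      # illustrative corpus size
test_percent = [0.05, 0]                # the script's default --test_percent

dev_lines = math.ceil(test_percent[0] * total_lines)   # 50 lines for val
test_lines = math.ceil(test_percent[1] * total_lines)  # 0 lines for test
train_lines = total_lines - (test_lines + dev_lines)   # 950 lines for train
assert [train_lines, dev_lines, test_lines] == [950, 50, 0]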
megatron/deprecated_data_utils/scripts/split_json.py
View file @ b7f1b050

@@ -3,7 +3,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.

Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
@@ -20,6 +20,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
                    help='percentage of available data to use for val/test dataset')
args = parser.parse_args()


def get_lines(filepath):
    lines = []
    with open(filepath, 'r') as f:
@@ -28,6 +29,7 @@ def get_lines(filepath):
            lines.append(l)
    return lines


def get_splits(lines, line_counts):
    all_lines = []
    line_idx = []
@@ -35,14 +37,14 @@ def get_splits(lines, line_counts):
    for i, l in enumerate(lines):
        all_lines.extend(l)
        line_idx.extend(list(range(len(l))))
        file_mappings.extend([i] * len(l))
    indices = list(range(len(all_lines)))
    random.shuffle(indices)
    all_lines = [all_lines[idx] for idx in indices]
    line_idx = [line_idx[idx] for idx in indices]
    file_mappings = [file_mappings[idx] for idx in indices]
    splits = []
    mappings = []
    start = 0
@@ -53,10 +55,11 @@ def get_splits(lines, line_counts):
        start = end
    return splits, mappings


def format_mappings(line_idx, file_mappings):
    lines = []
    for m, l in zip(file_mappings, line_idx):
        lines.append(str(m).strip() + '\t' + str(l).strip())
    return lines
@@ -70,25 +73,30 @@ def get_filepaths(filepaths, output_dir):
    paths.append(os.path.join(output_dir, test_path))
    return paths


def write_files(lines, mappings, filepaths):
    for l, m, path in zip(lines, mappings, filepaths):
        write_file(l, path)
        write_mapping_file(m, path)


def write_file(lines, path):
    print('Writing:', path)
    with open(path, 'w') as f:
        for l in lines:
            f.write(l + '\n')


def write_mapping_file(m, path):
    path = path + '.map'
    m = [get_mapping_header()] + m
    write_file(m, path)


def get_mapping_header():
    return 'file\tline #'


if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
@@ -98,16 +106,16 @@ for filepath in args.input_files:
    _lines = get_lines(filepath)
    lines.append(_lines)

# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent) == 2:
    test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
@@ -116,4 +124,3 @@ splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
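Both split scripts also write a sibling `.map` file next to each output json; its rows map every shuffled output line back to its source file index and original line number, under a `file\tline #` header. A compact restatement of `format_mappings` with made-up indices:

def format_mappings(line_idx, file_mappings):
    # One "<file index>\t<original line #>" row per shuffled output line,
    # as written to the *.json.map files.
    return [str(m).strip() + '\t' + str(l).strip() for m, l in zip(file_mappings, line_idx)]

rows = ['file\tline #'] + format_mappings([12, 0, 7], [0, 1, 0])
# rows == ['file\tline #', '0\t12', '1\t0', '0\t7']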
megatron/deprecated_data_utils/tf_dl.py
View file @ b7f1b050

@@ -14,20 +14,22 @@
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""

import queue
import threading

import tensorflow as tf
tf.enable_eager_execution()

import torch
import numpy as np


class TFRecordDataLoader(object):
    def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False):
        assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
        tf.set_random_seed(seed)
        if isinstance(records, str):
            records = [records]

        self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
                                                "input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
@@ -37,7 +39,7 @@ class TFRecordDataLoader(object):
                                                "masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
                                                "next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})
        # Instantiate dataset according to original BERT implementation
        if train:
            self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
            self.dataset = self.dataset.repeat()
@@ -55,10 +57,12 @@ class TFRecordDataLoader(object):
            self.dataset = self.dataset.repeat()

        # Instantiate dataloader (do not drop remainder for eval)
        loader_args = {'batch_size': batch_size,
                       'num_parallel_batches': num_workers,
                       'drop_remainder': train}
        self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args))
        self.threaded_dl = threaded_dl
        self.num_workers = num_workers
@@ -72,6 +76,7 @@ class TFRecordDataLoader(object):
            for item in data_iter:
                yield convert_tf_example_to_torch_tensors(item)


class Record2Example(object):
    def __init__(self, feature_map):
        self.feature_map = feature_map
@@ -84,23 +89,25 @@ class Record2Example(object):
            example[k] = tf.to_int32(v)
        return example


def convert_tf_example_to_torch_tensors(example):
    item = {k: (v.numpy()) for k, v in example.items()}
    mask = np.zeros_like(item['input_ids'])
    mask_labels = np.ones_like(item['input_ids']) * -1
    for b, row in enumerate(item['masked_lm_positions'].astype(int)):
        for i, idx in enumerate(row):
            if item['masked_lm_weights'][b, i] != 0:
                mask[b, idx] = 1
                mask_labels[b, idx] = item['masked_lm_ids'][b, i]
    output = {'text': item['input_ids'], 'types': item['segment_ids'], 'is_random': item['next_sentence_labels'],
              'pad_mask': 1 - item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
    return {k: torch.from_numpy(v) for k, v in output.items()}


class MultiprocessLoader(object):
    def __init__(self, dataloader, num_workers=2):
        self.dl = dataloader
        self.queue_size = 2 * num_workers

    def __iter__(self):
        output_queue = queue.Queue(self.queue_size)
@@ -114,8 +121,9 @@ class MultiprocessLoader(object):
        else:
            print(RuntimeError('TF record data loader thread exited unexpectedly'))


def _multiproc_iter(dl, output_queue):
    data_iter = iter(dl)
    for item in data_iter:
        tensors = convert_tf_example_to_torch_tensors(item)
        output_queue.put(tensors, block=True)
\ No newline at end of file
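`convert_tf_example_to_torch_tensors` scatters the sparse `masked_lm_positions`/`masked_lm_ids`/`masked_lm_weights` triplet into dense `mask` and `mask_labels` arrays. A NumPy-only sketch of that scatter step with toy values (a batch of one sequence of length 8; the concrete numbers are invented):

import numpy as np

# Toy batch: one sequence of length 8 with up to 3 masked positions.
input_ids = np.array([[5, 6, 7, 8, 9, 10, 0, 0]])
masked_lm_positions = np.array([[1, 4, 0]])      # third slot is padding
masked_lm_ids = np.array([[60, 90, 0]])
masked_lm_weights = np.array([[1.0, 1.0, 0.0]])  # weight 0 marks unused slots

mask = np.zeros_like(input_ids)
mask_labels = np.ones_like(input_ids) * -1
for b, row in enumerate(masked_lm_positions.astype(int)):
    for i, idx in enumerate(row):
        if masked_lm_weights[b, i] != 0:
            mask[b, idx] = 1
            mask_labels[b, idx] = masked_lm_ids[b, i]

assert mask.tolist() == [[0, 1, 0, 0, 1, 0, 0, 0]]
assert mask_labels.tolist() == [[-1, 60, -1, -1, 90, -1, -1, -1]]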
megatron/deprecated_data_utils/tokenization.py
View file @ b7f1b050

This diff is collapsed.