OpenDAS / Megatron-LM

Commit 16a64c41, authored May 03, 2020 by Neel Kant

Move get_train_val... to dataset_utils

parent 59031aa7
Showing 5 changed files with 100 additions and 98 deletions:

  megatron/data/bert_dataset.py    +0   -95
  megatron/data/dataset_utils.py   +97  -0
  pretrain_bert.py                 +1   -1
  pretrain_bert_ict.py             +1   -1
  pretrain_realm.py                +1   -1
megatron/data/bert_dataset.py

@@ -26,106 +26,11 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
-from megatron.data.realm_dataset import InverseClozeDataset
 from megatron import print_rank_0
 
 DATASET_TYPES = ['standard_bert', 'ict', 'realm']
 
-
-def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
-                                    train_valid_test_num_samples,
-                                    max_seq_length, masked_lm_prob,
-                                    short_seq_prob, seed, skip_warmup,
-                                    dataset_type='standard_bert'):
-
-    if dataset_type not in DATASET_TYPES:
-        raise ValueError("Invalid dataset_type: " + dataset_type)
-
-    # Indexed dataset.
-    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
-
-    if dataset_type == 'ict':
-        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
-                                             data_impl, skip_warmup)
-
-    # Get start and end indices of train/valid/test into doc-idx
-    # Note that doc-idx is designed to be num-docs + 1 so we can
-    # easily iterate over it.
-    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
-    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
-
-    # Print stats about the splits.
-    print_rank_0(' > dataset split:')
-
-    def print_split_stats(name, index):
-        print_rank_0('    {}:'.format(name))
-        print_rank_0('     document indices in [{}, {}) total of {} '
-                     'documents'.format(splits[index], splits[index + 1],
-                                        splits[index + 1] - splits[index]))
-        start_index = indexed_dataset.doc_idx[splits[index]]
-        end_index = indexed_dataset.doc_idx[splits[index + 1]]
-        print_rank_0('     sentence indices in [{}, {}) total of {} '
-                     'sentences'.format(start_index, end_index,
-                                        end_index - start_index))
-    print_split_stats('train', 0)
-    print_split_stats('validation', 1)
-    print_split_stats('test', 2)
-
-    def build_dataset(index, name):
-        from megatron.data.realm_dataset import RealmDataset
-        dataset = None
-        if splits[index + 1] > splits[index]:
-            # Get the pointer to the original doc-idx so we can set it later.
-            doc_idx_ptr = indexed_dataset.get_doc_idx()
-            # Slice the doc-idx.
-            start_index = splits[index]
-            # Add +1 so we can index into the dataset to get the upper bound.
-            end_index = splits[index + 1] + 1
-            # New doc_idx view.
-            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
-            # Build the dataset accordingly.
-            kwargs = dict(
-                name=name,
-                data_prefix=data_prefix,
-                num_epochs=None,
-                max_num_samples=train_valid_test_num_samples[index],
-                max_seq_length=max_seq_length,
-                short_seq_prob=short_seq_prob,
-                seed=seed)
-
-            if dataset_type == 'ict':
-                dataset = InverseClozeDataset(
-                    block_dataset=indexed_dataset,
-                    title_dataset=title_dataset,
-                    **kwargs)
-            else:
-                dataset_cls = BertDataset if dataset_type == 'standard_bert' \
-                    else RealmDataset
-                dataset = dataset_cls(
-                    indexed_dataset=indexed_dataset,
-                    masked_lm_prob=masked_lm_prob,
-                    **kwargs)
-
-            # Set the original pointer so dataset remains the main dataset.
-            indexed_dataset.set_doc_idx(doc_idx_ptr)
-            # Checks.
-            assert indexed_dataset.doc_idx[0] == 0
-            assert indexed_dataset.doc_idx.shape[0] == \
-                (total_num_of_documents + 1)
-        return dataset
-
-    train_dataset = build_dataset(0, 'train')
-    valid_dataset = build_dataset(1, 'valid')
-    test_dataset = build_dataset(2, 'test')
-
-    return (train_dataset, valid_dataset, test_dataset)
-
-
 class BertDataset(Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
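The function removed above (and re-added verbatim in dataset_utils.py below) builds each split with a slice-and-restore pattern: it saves the original doc-idx pointer, installs a view covering only that split's documents, constructs the dataset against the view, then restores the pointer and asserts the full index is intact. A minimal standalone sketch of that pattern, using a toy stand-in for the indexed dataset (get_doc_idx/set_doc_idx mirror the method names in the diff; everything else here is illustrative, not Megatron's API):

class ToyIndexedDataset:
    # Illustrative stand-in: doc_idx holds num_docs + 1 boundary entries.
    def __init__(self, doc_idx):
        self.doc_idx = doc_idx

    def get_doc_idx(self):
        return self.doc_idx

    def set_doc_idx(self, doc_idx):
        self.doc_idx = doc_idx


def build_split(ds, splits, index):
    # Save the original pointer so it can be restored afterwards.
    doc_idx_ptr = ds.get_doc_idx()
    # Install a view over this split only; +1 includes the upper-bound entry.
    ds.set_doc_idx(doc_idx_ptr[splits[index]:splits[index + 1] + 1])
    view = list(ds.doc_idx)  # stand-in for constructing the real dataset
    # Restore the original pointer and check the index was not clobbered.
    ds.set_doc_idx(doc_idx_ptr)
    assert ds.doc_idx[0] == 0
    return view


ds = ToyIndexedDataset(list(range(11)))  # 10 documents
print(build_split(ds, [0, 8, 10], 0))    # -> [0, 1, ..., 8]

The restore step matters because all three splits share one underlying indexed dataset; without it, the 'valid' and 'test' views would be sliced out of an already-sliced index.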
megatron/data/dataset_utils.py
@@ -22,6 +22,9 @@ import collections
 import itertools
 import numpy as np
 
+from megatron import print_rank_0
+from megatron.data.bert_dataset import DATASET_TYPES, get_indexed_dataset_, get_train_valid_test_split_, BertDataset
+from megatron.data.realm_dataset import InverseClozeDataset
 
 def compile_helper():
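Note that the two modules now import from each other at module level: bert_dataset.py keeps `from megatron.data.dataset_utils import build_training_sample` (see the first file above), while dataset_utils.py adds the bert_dataset import shown here. Whether this cycle resolves cleanly depends on import order. The usual way to break such a cycle, and the one build_dataset below already uses for RealmDataset, is to defer one of the imports into the function body. A generic sketch with illustrative module names, not code from this repository:

# module_a.py
SHARED_TYPES = ['standard_bert', 'ict', 'realm']
from module_b import build        # module-level import in one direction only

# module_b.py
def build(kind):
    # Deferred import: resolved at call time, after both modules have
    # finished loading, so the circular reference never bites.
    from module_a import SHARED_TYPES
    if kind not in SHARED_TYPES:
        raise ValueError('Invalid kind: ' + kind)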
@@ -406,3 +409,97 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     loss_mask_np = np.array(loss_mask, dtype=np.int64)
 
     return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
+
+
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
+                                    max_seq_length, masked_lm_prob,
+                                    short_seq_prob, seed, skip_warmup,
+                                    dataset_type='standard_bert'):
+
+    if dataset_type not in DATASET_TYPES:
+        raise ValueError("Invalid dataset_type: " + dataset_type)
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
+
+    if dataset_type == 'ict':
+        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
+                                             data_impl, skip_warmup)
+
+    # Get start and end indices of train/valid/test into doc-idx
+    # Note that doc-idx is designed to be num-docs + 1 so we can
+    # easily iterate over it.
+    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+
+    def print_split_stats(name, index):
+        print_rank_0('    {}:'.format(name))
+        print_rank_0('     document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+        start_index = indexed_dataset.doc_idx[splits[index]]
+        end_index = indexed_dataset.doc_idx[splits[index + 1]]
+        print_rank_0('     sentence indices in [{}, {}) total of {} '
+                     'sentences'.format(start_index, end_index,
+                                        end_index - start_index))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+
+    def build_dataset(index, name):
+        from megatron.data.realm_dataset import RealmDataset
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            # Get the pointer to the original doc-idx so we can set it later.
+            doc_idx_ptr = indexed_dataset.get_doc_idx()
+            # Slice the doc-idx.
+            start_index = splits[index]
+            # Add +1 so we can index into the dataset to get the upper bound.
+            end_index = splits[index + 1] + 1
+            # New doc_idx view.
+            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
+            # Build the dataset accordingly.
+            kwargs = dict(
+                name=name,
+                data_prefix=data_prefix,
+                num_epochs=None,
+                max_num_samples=train_valid_test_num_samples[index],
+                max_seq_length=max_seq_length,
+                short_seq_prob=short_seq_prob,
+                seed=seed)
+
+            if dataset_type == 'ict':
+                dataset = InverseClozeDataset(
+                    block_dataset=indexed_dataset,
+                    title_dataset=title_dataset,
+                    **kwargs)
+            else:
+                dataset_cls = BertDataset if dataset_type == 'standard_bert' \
+                    else RealmDataset
+                dataset = dataset_cls(
+                    indexed_dataset=indexed_dataset,
+                    masked_lm_prob=masked_lm_prob,
+                    **kwargs)
+
+            # Set the original pointer so dataset remains the main dataset.
+            indexed_dataset.set_doc_idx(doc_idx_ptr)
+            # Checks.
+            assert indexed_dataset.doc_idx[0] == 0
+            assert indexed_dataset.doc_idx.shape[0] == \
+                (total_num_of_documents + 1)
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+
+    return (train_dataset, valid_dataset, test_dataset)
\ No newline at end of file
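For reference, a call into the function at its new location might look like the sketch below; the prefix, split string, and sample counts are illustrative values, not taken from this repository, while the keyword names follow the signature in the diff above:

from megatron.data.dataset_utils import build_train_valid_test_datasets

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my-corpus_text_sentence',   # assumed dataset prefix
    data_impl='mmap',                        # assumed indexed-dataset backend
    splits_string='949,50,1',                # assumed train/valid/test ratio
    train_valid_test_num_samples=[10000, 1000, 100],
    max_seq_length=512,
    masked_lm_prob=0.15,
    short_seq_prob=0.1,
    seed=1234,
    skip_warmup=True,
    dataset_type='ict')                      # or 'standard_bert' / 'realm'

With dataset_type='ict', the function also expects a companion indexed dataset at data_prefix + '-titles', as the title_dataset lookup above shows.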
pretrain_bert.py

@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.data.bert_dataset import build_train_valid_test_datasets
+from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import reduce_losses
pretrain_bert_ict.py

@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.data.bert_dataset import build_train_valid_test_datasets
+from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
 from megatron.utils import reduce_losses
pretrain_realm.py

@@ -24,7 +24,7 @@ from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.data.bert_dataset import build_train_valid_test_datasets
+from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import REALMBertModel, REALMRetriever
 from megatron.training import pretrain
 from megatron.utils import reduce_losses