OpenDAS / Megatron-LM · Commits

Commit fcc500d6
Authored Jun 10, 2020 by Neel Kant
Parent: c044f59a

    Ran and cleaned up
Showing 5 changed files with 18 additions and 19 deletions:

    megatron/data/bert_dataset.py         +9  -8
    megatron/data/dataset_utils.py        +5  -3
    megatron/data/realm_dataset.py        +1  -0
    megatron/data/realm_dataset_utils.py  +2  -2
    megatron/model/realm_model.py         +1  -6
megatron/data/bert_dataset.py

```diff
@@ -25,7 +25,6 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer, get_args
 from megatron import mpu
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
-from megatron.data.dataset_utils import build_training_sample
 from megatron import print_rank_0
@@ -62,6 +61,8 @@ class BertDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
+        from megatron.data.dataset_utils import build_training_sample
+        self.build_sample_fn = build_training_sample

     def __len__(self):
         return self.samples_mapping.shape[0]
@@ -72,13 +73,13 @@ class BertDataset(Dataset):
         # Note that this rng state should be numpy and not python since
         # python randint is inclusive whereas the numpy one is exclusive.
         np_rng = np.random.RandomState(seed=(self.seed + idx))
-        return build_training_sample(sample, seq_length,
-                                     self.max_seq_length,  # needed for padding
-                                     self.vocab_id_list,
-                                     self.vocab_id_to_token_dict,
-                                     self.cls_id, self.sep_id,
-                                     self.mask_id, self.pad_id,
-                                     self.masked_lm_prob, np_rng)
+        return self.build_sample_fn(sample, seq_length,
+                                    self.max_seq_length,  # needed for padding
+                                    self.vocab_id_list,
+                                    self.vocab_id_to_token_dict,
+                                    self.cls_id, self.sep_id,
+                                    self.mask_id, self.pad_id,
+                                    self.masked_lm_prob, np_rng)


 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
```
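Moving the `build_training_sample` import from module scope into `BertDataset.__init__`, with the `self.build_sample_fn` indirection, is the standard way to break a circular import: this same commit makes `dataset_utils.py` import from `bert_dataset.py`, so the reverse top-level import would fail while the modules are loading. A minimal sketch of the pattern (file and function names here are illustrative, not the actual Megatron code):

```python
# dataset_utils_sketch.py -- imports from the dataset module at top level.
from bert_dataset_sketch import get_indexed_dataset_

def build_training_sample(sample, seq_length):
    """Stand-in for the real sample builder."""
    return sample[:seq_length]


# bert_dataset_sketch.py -- must NOT import dataset_utils_sketch at top
# level: dataset_utils_sketch is still half-initialized while it is busy
# importing *this* module, so the name would not exist yet.
def get_indexed_dataset_(prefix):
    return list(range(10))

class BertDataset:
    def __init__(self):
        # Deferred import: by the time __init__ runs, both modules have
        # finished loading, so this lookup succeeds.
        from dataset_utils_sketch import build_training_sample
        self.build_sample_fn = build_training_sample

    def __getitem__(self, idx):
        return self.build_sample_fn([idx] * 8, seq_length=4)
```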
megatron/data/dataset_utils.py

```diff
@@ -23,7 +23,7 @@ import itertools
 import numpy as np

 from megatron import print_rank_0, get_args
-from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
+from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_

 DATASET_TYPES = ['standard_bert', 'ict']
@@ -426,8 +426,9 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                           data_impl,
                                           skip_warmup)

-    if dataset_type in ['ict', 'realm']:
-        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
+    if dataset_type in ['ict']:
+        args = get_args()
+        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                              data_impl,
                                              skip_warmup)
@@ -455,6 +456,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
+        from megatron.data.bert_dataset import BertDataset
         from megatron.data.realm_dataset import ICTDataset
         dataset = None
         if splits[index + 1] > splits[index]:
```
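The second hunk replaces the implicit `data_prefix + '-titles'` naming convention with an explicit `args.titles_data_path` read from `get_args()`, so the titles dataset can live anywhere on disk. A sketch of how such a flag could be wired up with plain argparse (the flag name matches the attribute used above; the parser code is illustrative, not Megatron's actual `arguments.py`):

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Illustrative data args')
    parser.add_argument('--data-path', type=str, required=True,
                        help='Prefix of the main indexed dataset.')
    # argparse turns --titles-data-path into args.titles_data_path,
    # the attribute read in build_train_valid_test_datasets above.
    parser.add_argument('--titles-data-path', type=str, default=None,
                        help='Path to the titles indexed dataset (ICT only).')
    return parser.parse_args()

args = parse_args()
if args.titles_data_path is None:
    # Fall back to the old prefix convention this commit removes.
    args.titles_data_path = args.data_path + '-titles'
```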
megatron/data/realm_dataset.py

```diff
@@ -5,6 +5,7 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
+from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list


 class ICTDataset(Dataset):
```
megatron/data/realm_dataset_utils.py

```diff
@@ -79,8 +79,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
         # parallel case
         counts = torch.cuda.LongTensor([1])
         torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-        # assert counts[0].item() == torch.distributed.get_world_size(
-        #    group=mpu.get_data_parallel_group())
+        assert counts[0].item() == torch.distributed.get_world_size(
+            group=mpu.get_data_parallel_group())

     # Load indexed dataset.
     print_rank_0(' > loading indexed mapping from {}'.format(
```
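The re-enabled assert is a lightweight rendezvous check: every rank in the data-parallel group contributes a 1, the all_reduce sums the contributions, and the total must equal the group's world size before any rank proceeds to load the samples mapping. A self-contained sketch of the same pattern (a single-process gloo group on CPU so it runs standalone; the original uses `torch.cuda.LongTensor` because the NCCL backend requires GPU tensors):

```python
import torch
import torch.distributed as dist

def rendezvous_check(group=None):
    """Sum a 1 from every rank; the result must equal the group size.

    Acts as a cheap barrier-with-verification: the all_reduce blocks
    until every rank in the group has called it, and the assert catches
    a rank that crashed or joined the wrong process group.
    """
    counts = torch.ones(1, dtype=torch.long)
    dist.all_reduce(counts, group=group)  # default op is SUM
    assert counts[0].item() == dist.get_world_size(group=group)

if __name__ == '__main__':
    # Minimal single-process group, just to make the sketch runnable.
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                            rank=0, world_size=1)
    rendezvous_check()
    print('all ranks accounted for')
```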
megatron/model/realm_model.py

```diff
 import os

 import numpy as np
 import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron.checkpointing import load_checkpoint, get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
+from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
 from megatron.utils import report_memory
 from megatron import mpu
```