OpenDAS / Megatron-LM
"src/vscode:/vscode.git/clone" did not exist on "22e8b2ca2b38ff924ca1cbf82ae7e34b51d1d61a"
Commit 0f5e2809, authored May 06, 2020 by Neel Kant

Resolve internal merge conflict

Parent: f2094783
Showing 6 changed files with 190 additions and 8 deletions (+190, -8)
megatron/arguments.py            +2   -0
megatron/data/bert_dataset.py    +1   -2
megatron/data/dataset_utils.py   +3   -1
megatron/data/ict_dataset.py     +178 -0
megatron/data/realm_dataset.py   +5   -4
megatron/data/realm_index.py     +1   -1
megatron/arguments.py

@@ -365,6 +365,8 @@ def _add_data_args(parser):
                        'end-of-document token.')
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens.')
+    group.add_argument('--query-in-block-prob', type=float, default=0.1,
+                       help='Probability of keeping query in block for ICT dataset')
 
     return parser
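As context for the new flag, here is a minimal standalone sketch of how it behaves with argparse; the parser setup and group title below are assumptions for illustration, not code from this commit.

# Minimal sketch (not from the commit) of the new flag's default and override behavior.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='data')  # hypothetical group title
group.add_argument('--query-in-block-prob', type=float, default=0.1,
                   help='Probability of keeping query in block for ICT dataset')

args = parser.parse_args([])  # flag omitted: falls back to the 0.1 default
assert args.query_in_block_prob == 0.1
args = parser.parse_args(['--query-in-block-prob', '0.25'])
assert args.query_in_block_prob == 0.25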
megatron/data/bert_dataset.py

@@ -22,13 +22,12 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
-from megatron import get_tokenizer
+from megatron import get_tokenizer, get_args
 from megatron import mpu
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
 
 
 class BertDataset(Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
megatron/data/dataset_utils.py

@@ -22,7 +22,7 @@ import collections
 import itertools
 import numpy as np
 
-from megatron import print_rank_0
+from megatron import print_rank_0, get_args
 from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
 
 DATASET_TYPES = ['standard_bert', 'ict', 'realm']

@@ -478,9 +478,11 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             )
 
             if dataset_type == 'ict':
+                args = get_args()
                 dataset = ICTDataset(
                     block_dataset=indexed_dataset,
                     title_dataset=title_dataset,
+                    query_in_block_prob=args.query_in_block_prob,
                     **kwargs
                 )
             elif dataset_type == 'realm':
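The wiring here is straightforward: the global argument namespace is fetched with get_args() and the flag value is forwarded to the dataset constructor. A hedged sketch of that pattern using stand-in stubs (the _Args class and the minimal ICTDataset below are illustrative, not the Megatron implementations):

# Illustrative stubs only; real code uses megatron.get_args() and the real ICTDataset.
class _Args:
    query_in_block_prob = 0.1

def get_args():
    return _Args()

class ICTDataset:
    def __init__(self, block_dataset=None, title_dataset=None,
                 query_in_block_prob=0.1, **kwargs):
        self.query_in_block_prob = query_in_block_prob

args = get_args()
dataset = ICTDataset(block_dataset=None, title_dataset=None,
                     query_in_block_prob=args.query_in_block_prob)
assert dataset.query_in_block_prob == 0.1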
megatron/data/ict_dataset.py (new file, mode 100644)

import itertools
import random
import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import mpu
from megatron.data import helpers


class InverseClozeDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 query_in_block_prob, short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)

        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        title = list(self.title_dataset[int(doc_idx)])
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        # avoid selecting the first or last sentence to be the query.
        if len(block) == 2:
            rand_sent_idx = int(self.rng.random() > 0.5)
        else:
            rand_sent_idx = self.rng.randint(1, len(block) - 2)

        # keep the query in the context 10% of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }

        return sample

    def encode_text(self, text):
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(token for token in tokens if token != '[PAD]')

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def concat_and_pad_tokens(self, tokens, title=None):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = [self.cls_id] + tokens + [self.sep_id]
        if title is not None:
            tokens += title + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask

    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
                not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(self.name))

            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
            time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))

        return samples_mapping
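To make the sampling rule in __getitem__ easier to follow, here is a standalone sketch of the query-selection logic, using plain Python lists in place of the indexed block dataset; select_query and the toy token IDs are illustrative names, not part of the commit.

# Standalone sketch of the inverse-cloze query selection (illustrative, not from the commit).
import random

def select_query(block, query_in_block_prob, rng):
    """Pick one sentence of a multi-sentence block as the ICT query."""
    assert len(block) > 1
    # avoid selecting the first or last sentence to be the query.
    if len(block) == 2:
        rand_sent_idx = int(rng.random() > 0.5)
    else:
        rand_sent_idx = rng.randint(1, len(block) - 2)
    # with probability query_in_block_prob the query stays in the block;
    # otherwise it is removed, so the context does not contain it verbatim.
    if rng.random() < query_in_block_prob:
        query = block[rand_sent_idx].copy()
    else:
        query = block.pop(rand_sent_idx)
    return query, block

rng = random.Random(1234)
sentences = [[101, 102], [201, 202], [301, 302], [401, 402]]  # toy token-id sentences
query, context = select_query([s.copy() for s in sentences],
                              query_in_block_prob=0.1, rng=rng)
print(query, context)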
megatron/data/realm_dataset.py

@@ -4,7 +4,7 @@ import random
 import time
 
 import numpy as np
-import spacy
+# import spacy
 import torch
 from torch.utils.data import Dataset
 
@@ -38,7 +38,7 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
     return train_sample
 
 
-qa_nlp = spacy.load('en_core_web_lg')
+# qa_nlp = spacy.load('en_core_web_lg')
 
 
 def salient_span_mask(tokens, vocab_id_list, vocab_id_to_token_dict,

@@ -357,10 +357,11 @@ class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
-                 short_seq_prob, seed, use_titles=True):
+                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
+        self.query_in_block_prob = query_in_block_prob
         self.block_dataset = block_dataset
         self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob

@@ -394,7 +395,7 @@ class ICTDataset(Dataset):
             rand_sent_idx = self.rng.randint(0, len(block) - 1)
 
         # keep the query in the context 10% of the time.
-        if self.rng.random() < 1:
+        if self.rng.random() < self.query_in_block_prob:
             query = block[rand_sent_idx].copy()
         else:
             query = block.pop(rand_sent_idx)
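The last hunk is the behavioral fix in this file: with the old condition, random() is always below 1, so the query was kept in the block on every sample; with the new condition it is kept only about query_in_block_prob of the time, which matches the "10% of the time" comment at the default of 0.1. A quick sketch of the difference (illustrative only, not from the commit):

# Illustrative comparison of the old and new gate conditions.
import random

rng = random.Random(0)
query_in_block_prob = 0.1
trials = 10000

old_keep = sum(rng.random() < 1 for _ in range(trials))                    # always true
new_keep = sum(rng.random() < query_in_block_prob for _ in range(trials))  # true ~10% of the time
print(old_keep / trials, new_keep / trials)  # 1.0 vs roughly 0.1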
megatron/data/realm_index.py

@@ -3,7 +3,6 @@ import os
 import pickle
 import shutil
 
-import faiss
 import numpy as np
 import torch
 
@@ -103,6 +102,7 @@ class FaissMIPSIndex(object):
         return new_index
 
     def get_block_index(self):
+        import faiss
         INDEX_TYPES = ['flat_l2', 'flat_ip']
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")
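Moving the faiss import from module level into get_block_index defers the dependency until an index is actually built, so the rest of the module can be imported on machines without faiss installed. A hedged sketch of that lazy-import pattern (the class body below is a stub, not the real FaissMIPSIndex):

# Stub illustrating the lazy-import pattern only; the real class builds and uses the index.
class FaissMIPSIndex(object):
    def __init__(self, index_type='flat_ip'):
        self.index_type = index_type

    def get_block_index(self):
        import faiss  # imported only when an index is actually constructed
        INDEX_TYPES = ['flat_l2', 'flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")
        # an index such as faiss.IndexFlatIP / faiss.IndexFlatL2 would be built here
        return None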