OpenDAS / Megatron-LM · Commits

Commit bf599e86
authored Jun 08, 2020 by Neel Kant
parent 91158c9b

Correct retrieval utility and add salient span preprocessing

Showing 9 changed files with 218 additions and 75 deletions (+218 -75):
    megatron/arguments.py                 +4   -0
    megatron/data/dataset_utils.py        +7   -3
    megatron/data/realm_dataset.py        +12  -3
    megatron/data/realm_dataset_utils.py  +44  -11
    megatron/data/realm_index.py          +9   -7
    megatron/model/realm_model.py         +28  -20
    megatron/training.py                  +1   -1
    pretrain_realm.py                     +26  -25
    tools/preprocess_data.py              +87  -5
megatron/arguments.py

@@ -389,6 +389,10 @@ def _add_data_args(parser):
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
     group.add_argument('--faiss-use-gpu', action='store_true')
+    group.add_argument('--index-reload-interval', type=int, default=500)
+    group.add_argument('--use-regular-masking', action='store_true')
+    group.add_argument('--allow-trivial-doc', action='store_true')
+    group.add_argument('--ner-data-path', type=str, default=None)
     return parser
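For reference, a minimal self-contained sketch of how the four new flags parse, using plain argparse outside Megatron's get_args() plumbing (flag names and defaults are taken from the diff above; the sample value is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--index-reload-interval', type=int, default=500)
parser.add_argument('--use-regular-masking', action='store_true')
parser.add_argument('--allow-trivial-doc', action='store_true')
parser.add_argument('--ner-data-path', type=str, default=None)

args = parser.parse_args(['--ner-data-path', 'wiki-text-ner'])
assert args.index_reload_interval == 500   # default reload interval
assert not args.use_regular_masking        # store_true flags default to False
print(args.ner_data_path)                  # wiki-text-ner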
megatron/data/dataset_utils.py

@@ -417,7 +417,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     dataset_type='standard_bert'):
+    args = get_args()
     if dataset_type not in DATASET_TYPES:
         raise ValueError("Invalid dataset_type: ", dataset_type)

@@ -427,7 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                             skip_warmup)
     if dataset_type in ['ict', 'realm']:
-        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
+        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                              data_impl,
                                              skip_warmup)

@@ -479,7 +479,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         )
         if dataset_type == 'ict':
-            args = get_args()
             dataset = ICTDataset(
                 block_dataset=indexed_dataset,
                 title_dataset=title_dataset,

@@ -487,6 +486,11 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                 **kwargs
             )
         elif dataset_type == 'realm':
+            if args.ner_data_path is not None:
+                ner_dataset = get_indexed_dataset_(args.ner_data_path,
+                                                   data_impl, skip_warmup)
+                kwargs.update({'ner_dataset': ner_dataset})
             dataset = REALMDataset(
                 block_dataset=indexed_dataset,
                 title_dataset=title_dataset,
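A small sketch (hypothetical names, placeholder objects) of the optional-kwargs pattern used in the last hunk: the NER indexed dataset is injected only when --ner-data-path is set, so REALMDataset keeps its default of ner_dataset=None for existing configs:

def make_realm_dataset(block_dataset, title_dataset, ner_dataset=None):
    # ner_dataset stays None unless explicitly provided via kwargs
    return {'blocks': block_dataset, 'titles': title_dataset, 'ner': ner_dataset}

kwargs = {}
ner_data_path = 'wiki-text-ner'   # stands in for args.ner_data_path
if ner_data_path is not None:
    kwargs.update({'ner_dataset': 'ner-indexed-dataset'})  # placeholder object
dataset = make_realm_dataset('block-indexed-dataset', 'title-indexed-dataset', **kwargs)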
megatron/data/realm_dataset.py

@@ -18,9 +18,9 @@ class REALMDataset(Dataset):
     Presumably
     """
-    def __init__(self, name, block_dataset, title_dataset, data_prefix,
-                 num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed):
+    def __init__(self, name, block_dataset, title_dataset,
+                 data_prefix, num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed, ner_dataset=None):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length

@@ -29,6 +29,7 @@ class REALMDataset(Dataset):
         self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)
+        self.ner_dataset = ner_dataset
         self.samples_mapping = get_block_samples_mapping(
             block_dataset, title_dataset, data_prefix, num_epochs,

@@ -48,7 +49,14 @@ class REALMDataset(Dataset):
     def __getitem__(self, idx):
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        # print([len(list(self.block_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
         assert len(block) > 1
+        block_ner_mask = None
+        if self.ner_dataset is not None:
+            block_ner_mask = [list(self.ner_dataset[i]) for i in range(start_idx, end_idx)]
+            # print([len(list(self.ner_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
+
         np_rng = np.random.RandomState(seed=(self.seed + idx))
         sample = build_realm_training_sample(block,

@@ -60,6 +68,7 @@ class REALMDataset(Dataset):
                                              self.mask_id,
                                              self.pad_id,
                                              self.masked_lm_prob,
+                                             block_ner_mask,
                                              np_rng)
         sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
         return sample
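A toy sketch of the alignment assumption behind __getitem__ above: the NER dataset is built (by tools/preprocess_data.py, below) to contain one 0/1 mask row per sentence row of the block dataset, so slicing both with the same [start_idx, end_idx) range yields parallel lists. The ids and masks here are made up:

block_dataset = [[101, 7592, 2088], [1996, 4937]]   # token ids per sentence (toy)
ner_dataset   = [[0, 1, 1],         [0, 1]]         # 1 marks salient-entity tokens (toy)

start_idx, end_idx = 0, 2
block = [list(block_dataset[i]) for i in range(start_idx, end_idx)]
block_ner_mask = [list(ner_dataset[i]) for i in range(start_idx, end_idx)]
assert all(len(s) == len(m) for s, m in zip(block, block_ner_mask))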
megatron/data/realm_dataset_utils.py

@@ -8,7 +8,7 @@ import spacy
 import torch

 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, print_rank_0, mpu

 SPACY_NER = spacy.load('en_core_web_lg')

@@ -16,19 +16,30 @@ SPACY_NER = spacy.load('en_core_web_lg')
 def build_realm_training_sample(sample, max_seq_length,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 cls_id, sep_id, mask_id, pad_id,
-                                masked_lm_prob, np_rng):
+                                masked_lm_prob, block_ner_mask, np_rng):
     tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
     tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)

-    try:
-        masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
-    except TypeError:
-        # this means the above returned None, and None isn't iterable.
-        # TODO: consider coding style.
+    args = get_args()
+    if args.use_regular_masking:
         max_predictions_per_seq = masked_lm_prob * max_seq_length
         masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
             tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
             cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
+    elif block_ner_mask is not None:
+        block_ner_mask = list(itertools.chain(*block_ner_mask))[:max_seq_length - 2]
+        block_ner_mask = [0] + block_ner_mask + [0]
+        masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(
+            tokens, block_ner_mask, mask_id)
+    else:
+        try:
+            masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
+        except TypeError:
+            # this means the above returned None, and None isn't iterable.
+            # TODO: consider coding style.
+            max_predictions_per_seq = masked_lm_prob * max_seq_length
+            masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
+                tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+                cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
         = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,

@@ -43,6 +54,28 @@ def build_realm_training_sample(sample, max_seq_length,
     return train_sample

+def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
+    tokenizer = get_tokenizer()
+    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
+    masked_tokens = tokens.copy()
+    masked_positions = []
+    masked_labels = []
+    for i in range(len(tokens)):
+        if block_ner_mask[i] == 1:
+            masked_positions.append(i)
+            masked_labels.append(tokens[i])
+            masked_tokens[i] = mask_id
+
+    # print("-" * 100 + '\n',
+    #       "TOKEN STR\n", tokens_str + '\n',
+    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)), flush=True)
+    return masked_tokens, masked_positions, masked_labels

 def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
     tokens = []
     tokens.append(cls_id)

@@ -119,10 +152,10 @@ def salient_span_mask(tokens, mask_id):
     for id_idx in masked_positions:
         labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id

-    #print("-" * 100 + '\n',
-    #      "TOKEN STR\n", tokens_str + '\n',
-    #      "SELECTED ENTITY\n", selected_entity.text + '\n',
-    #      "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
+    # print("-" * 100 + '\n',
+    #       "TOKEN STR\n", tokens_str + '\n',
+    #       "SELECTED ENTITY\n", selected_entity.text + '\n',
+    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
     return output_tokens, masked_positions, labels
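A standalone sketch of the NER-mask masking path added above, with toy ids in place of a real vocabulary (103 stands in for BERT's [MASK] id): positions where block_ner_mask is 1 are recorded and replaced by mask_id, exactly the loop in get_arrays_using_ner_mask:

tokens = [101, 2054, 2003, 3000, 102]   # [CLS] what is paris [SEP] (toy ids)
block_ner_mask = [0, 0, 0, 1, 0]        # CLS/SEP positions padded with 0, as in the diff
mask_id = 103

masked_tokens = tokens.copy()
masked_positions, masked_labels = [], []
for i, flagged in enumerate(block_ner_mask):
    if flagged == 1:
        masked_positions.append(i)
        masked_labels.append(tokens[i])
        masked_tokens[i] = mask_id

assert masked_tokens == [101, 2054, 2003, 103, 102]
assert masked_positions == [3] and masked_labels == [3000]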
megatron/data/realm_index.py

@@ -16,9 +16,11 @@ def detach(tensor):
 class BlockData(object):
     def __init__(self):
+        args = get_args()
         self.embed_data = dict()
         self.meta_data = dict()
-        self.temp_dir_name = 'temp_block_data'
+        block_data_path = os.path.splitext(args.block_data_path)[0]
+        self.temp_dir_name = block_data_path + '_tmp'

     def state(self):
         return {

@@ -150,12 +152,12 @@ class FaissMIPSIndex(object):
             for j in range(block_indices.shape[1]):
                 fresh_indices[i, j] = self.id_map[block_indices[i, j]]
         block_indices = fresh_indices

-        args = get_args()
-        if args.rank == 0:
-            torch.save({'query_embeds': query_embeds,
-                        'id_map': self.id_map,
-                        'block_indices': block_indices,
-                        'distances': distances}, 'search.data')
+        # args = get_args()
+        # if args.rank == 0:
+        #     torch.save({'query_embeds': query_embeds,
+        #                 'id_map': self.id_map,
+        #                 'block_indices': block_indices,
+        #                 'distances': distances}, 'search.data')
         return distances, block_indices

     # functions below are for ALSH, which currently isn't being used
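A quick sketch of the new temp-directory naming: the scratch directory is now derived from --block-data-path instead of the fixed 'temp_block_data', so jobs with different block data files no longer collide (the path below is hypothetical):

import os

block_data_path = '/checkpoints/realm/block_data.pkl'   # hypothetical --block-data-path
temp_dir_name = os.path.splitext(block_data_path)[0] + '_tmp'
print(temp_dir_name)  # /checkpoints/realm/block_data_tmp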
megatron/model/realm_model.py

@@ -114,8 +114,15 @@ class REALMBertModel(MegatronModule):
         # [batch_size x k x seq_length]
-        topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
-            tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
+        args = get_args()
+        if args.allow_trivial_doc:
+            topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
+                tokens, attention_mask, query_block_indices=None, include_null_doc=True)
+        else:
+            topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
+                tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
+
         # print("Top k block shape: ", topk_block_tokens.shape, flush=True)
         batch_size = tokens.shape[0]

@@ -130,15 +137,16 @@ class REALMBertModel(MegatronModule):
         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
+        fresh_block_logits = true_model.embed_block(topk_block_tokens, topk_block_attention_mask)
         fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1).float()
         # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

         # [batch_size x 1 x embed_size]
-        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1).float()
+        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(1).float()

         # [batch_size x k]
         fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
+        # fresh_block_scores = fresh_block_scores / np.sqrt(query_logits.shape[2])
         block_probs = F.softmax(fresh_block_scores, dim=1)

         # [batch_size * k x seq_length]

@@ -163,7 +171,7 @@ class REALMBertModel(MegatronModule):
         # block body ends after the second SEP
         block_ends = block_sep_indices[:, 1, 1] + 1

-        print('-' * 100)
+        # print('-' * 100)
         for row_num in range(all_tokens.shape[0]):
             q_len = query_lengths[row_num]
             b_start = block_starts[row_num]

@@ -176,24 +184,24 @@ class REALMBertModel(MegatronModule):
             all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
             all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
-            print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
+            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
             all_attention_mask[row_num, :new_tokens_length] = 1
             all_attention_mask[row_num, new_tokens_length:] = 0

-        print('-' * 100)
-        args = get_args()
-        if args.rank == 0:
-            torch.save({'lm_tokens': all_tokens,
-                        'lm_attn_mask': all_attention_mask,
-                        'query_tokens': tokens,
-                        'query_attn_mask': attention_mask,
-                        'query_logits': query_logits,
-                        'block_tokens': topk_block_tokens,
-                        'block_attn_mask': topk_block_attention_mask,
-                        'block_logits': fresh_block_logits,
-                        'block_probs': block_probs,
-                        }, 'final_lm_inputs.data')
+        # print('-' * 100)
+        # args = get_args()
+        # if args.rank == 0:
+        #     torch.save({'lm_tokens': all_tokens,
+        #                 'lm_attn_mask': all_attention_mask,
+        #                 'query_tokens': tokens,
+        #                 'query_attn_mask': attention_mask,
+        #                 'query_logits': query_logits,
+        #                 'block_tokens': topk_block_tokens,
+        #                 'block_attn_mask': topk_block_attention_mask,
+        #                 'block_logits': fresh_block_logits,
+        #                 'block_probs': block_probs,
+        #                 }, 'final_lm_inputs.data')
         # assert all(torch.equal(all_tokens[i], all_tokens[0]) for i in range(self.top_k))
         # assert all(torch.equal(all_attention_mask[i], all_attention_mask[0]) for i in range(self.top_k))
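A toy recomputation of the retrieval distribution used above: fresh query and block embeddings are re-scored with a dot product and softmaxed over the k retrieved blocks. Shapes mirror the comments in the diff; the values are random, so this only illustrates the shape bookkeeping:

import torch
import torch.nn.functional as F

batch_size, top_k, embed_size = 2, 4, 8
query_logits = torch.randn(batch_size, 1, embed_size)            # [batch_size x 1 x embed_size]
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)  # [batch_size x k x embed_size]

scores = torch.matmul(query_logits, fresh_block_logits.transpose(1, 2)).squeeze(1)
block_probs = F.softmax(scores, dim=1)                           # [batch_size x k]
assert torch.allclose(block_probs.sum(dim=1), torch.ones(batch_size))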
megatron/training.py

@@ -394,7 +394,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank,
                                                   group=get_gloo_comm_group(), async_op=True)
         last_reload_iteration = iteration
     while iteration < args.train_iters:
-        if args.max_training_rank is not None and iteration >= last_reload_iteration + 100:
+        if args.max_training_rank is not None and iteration >= last_reload_iteration + args.index_reload_interval:
             if recv_handle.is_completed():
                 # should add check that INDEX_READY == 1 but what else could be happening
                 true_model = model
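The hunk only swaps the hard-coded 100 for the new --index-reload-interval flag; a hedged sketch of the polling decision it parameterizes (recv_handle is the work object returned by torch.distributed.broadcast(..., async_op=True), whose is_completed() method is the real API used above):

def should_reload_index(iteration, last_reload_iteration, interval, recv_handle):
    # A reload is only considered once the configured interval has elapsed,
    # and only acted on once the async INDEX_READY broadcast has completed.
    if iteration < last_reload_iteration + interval:
        return False
    return recv_handle.is_completed()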
pretrain_realm.py

@@ -101,7 +101,7 @@ def forward_step(data_iterator, model):
     # print('labels shape: ', labels.shape, flush=True)
     with torch.no_grad():
-        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
+        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, tokens_over_batch = mpu.checkpoint(
             get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))

@@ -118,7 +118,7 @@ def forward_step(data_iterator, model):
     #             'tokens': tokens.cpu(),
     #             'pad_mask': pad_mask.cpu(),
     #             }, 'tensors.data')
-    # torch.load('gagaga')
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(relevant_logits)
     # print(torch.sum(block_probs, dim=1), flush=True)

@@ -131,58 +131,59 @@ def forward_step(data_iterator, model):
         l_probs = torch.log(marginalized_probs)
         return l_probs

-    log_probs = mpu.checkpoint(get_log_probs, relevant_logits, block_probs)
-
     def get_loss(l_probs, labs):
         vocab_size = l_probs.shape[2]
         loss = torch.nn.NLLLoss(ignore_index=-1)(l_probs.reshape(-1, vocab_size), labs.reshape(-1))
         # loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
         return loss.float()

-    lm_loss = mpu.checkpoint(get_loss, log_probs, labels)
-    # marginalized_logits = torch.sum(relevant_logits * block_probs, dim=1)
-    # vocab_size = marginalized_logits.shape[2]
-    # lm_loss_ = torch.nn.CrossEntropyLoss()(marginalized_logits.reshape(-1, vocab_size), labels.reshape(-1))
-    # lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-
-    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility,
-                                  avg_retrieval_utility, null_block_probs])
+    lm_loss = get_loss(get_log_probs(relevant_logits, block_probs), labels)
+    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility,
+                                  avg_retrieval_utility, null_block_probs, tokens_over_batch])
     # reduced_loss = reduce_losses([lm_loss])
     # torch.cuda.synchronize()

     return lm_loss, {'lm_loss': reduced_loss[0],
                      'max_ru': reduced_loss[1],
                      'top_ru': reduced_loss[2],
                      'avg_ru': reduced_loss[3],
-                     'null_prob': reduced_loss[4]}
+                     'null_prob': reduced_loss[4],
+                     'mask/batch': reduced_loss[5]}


 def get_retrieval_utility(lm_logits_, block_probs, labels, loss_mask):
     """log P(y | z, x) - log P(y | null, x)"""
-    # [batch x seq_len x vocab_size]
+    # [batch x top_k x seq_len x vocab_size]
     lm_logits = lm_logits_[:, :, :labels.shape[1], :]
-    #non_null_block_probs = block_probs[:, :-1]
-    #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
-    # non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :])
+    batch_size, top_k = lm_logits.shape[0], lm_logits.shape[1]
+    # non_null_block_probs = block_probs[:, :-1]
+    # non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
+    # non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])

     null_block_lm_logits = lm_logits[:, -1, :, :]
     null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                         labels.contiguous())
     null_block_loss = torch.sum(
-        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / batch_size
+        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

     retrieved_block_losses = []
-    for block_num in range(lm_logits.shape[1] - 1):
+    for block_num in range(top_k - 1):
         retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
         retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                  labels.contiguous())
-        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
+        # retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
         retrieved_block_loss = torch.sum(
-            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / batch_size
         retrieved_block_losses.append(retrieved_block_loss)

-    avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)
+    avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (top_k - 1)
     max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
     top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
     avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss
-    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility
+    tokens_over_batch = loss_mask.sum().float() / batch_size
+    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, tokens_over_batch


 def qa_forward_step(data_iterator, model):
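To make the corrected metric concrete, here is a toy, self-contained recomputation of retrieval utility, log P(y | z, x) - log P(y | null, x), as a difference of masked-LM losses. Plain F.cross_entropy stands in for mpu.vocab_parallel_cross_entropy, and a single per-masked-token normalization is used throughout as a simplification of the diff's bookkeeping; all tensors are random:

import torch
import torch.nn.functional as F

batch, top_k, seq, vocab = 2, 4, 5, 11
lm_logits = torch.randn(batch, top_k, seq, vocab)   # last block slot is the null block
labels = torch.randint(0, vocab, (batch, seq))
loss_mask = torch.ones(batch, seq)                  # 1 where a token was masked

def block_loss(block_logits):
    # Masked-LM loss for one evidence block, averaged over masked tokens.
    losses = F.cross_entropy(block_logits.reshape(-1, vocab),
                             labels.reshape(-1), reduction='none')
    return torch.sum(losses * loss_mask.reshape(-1)) / loss_mask.sum()

null_block_loss = block_loss(lm_logits[:, -1])
retrieved = [block_loss(lm_logits[:, k]) for k in range(top_k - 1)]

max_ru = null_block_loss - min(retrieved)              # best retrieved block
top_ru = null_block_loss - retrieved[0]                # rank-1 retrieved block
avg_ru = null_block_loss - sum(retrieved) / (top_k - 1)
tokens_over_batch = loss_mask.sum() / batch            # the new 'mask/batch' metric

Positive utility means a retrieved block made the masked tokens easier to predict than the null (no-evidence) block.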
tools/preprocess_data.py

@@ -24,6 +24,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
 import time

+import numpy as np
 import torch
 try:
     import nltk

@@ -31,8 +32,11 @@ try:
 except ImportError:
     nltk_available = False

 from megatron.tokenizer import build_tokenizer
 from megatron.data import indexed_dataset
+from megatron.data.realm_dataset_utils import id_to_str_pos_map


 # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer

@@ -75,6 +79,14 @@ class Encoder(object):
         else:
             Encoder.splitter = IdentitySplitter()

+        try:
+            import spacy
+            print("> Loading spacy")
+            Encoder.spacy = spacy.load('en_core_web_lg')
+            print(">> Finished loading spacy")
+        except:
+            Encoder.spacy = None
+
     def encode(self, json_line):
         data = json.loads(json_line)
         ids = {}

@@ -90,6 +102,56 @@ class Encoder(object):
             ids[key] = doc_ids
         return ids, len(json_line)

+    def encode_with_ner(self, json_line):
+        if self.spacy is None:
+            raise ValueError('Cannot do NER without spacy')
+        data = json.loads(json_line)
+        ids = {}
+        ner_masks = {}
+        for key in self.args.json_keys:
+            text = data[key]
+            doc_ids = []
+            doc_ner_mask = []
+            for sentence in Encoder.splitter.tokenize(text):
+                sentence_ids = Encoder.tokenizer.tokenize(sentence)
+                if len(sentence_ids) > 0:
+                    doc_ids.append(sentence_ids)
+
+                # sentence is cased?
+                # print(sentence)
+                entities = self.spacy(sentence).ents
+                undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
+                entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
+                # entities = []
+
+                masked_positions = []
+                if len(entities) > 0:
+                    entity_idx = np.random.randint(0, len(entities))
+                    selected_entity = entities[entity_idx]
+                    token_pos_map = id_to_str_pos_map(sentence_ids, Encoder.tokenizer)
+
+                    mask_start = mask_end = 0
+                    set_mask_start = False
+                    while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+                        if token_pos_map[mask_start] > selected_entity.start_char:
+                            set_mask_start = True
+                        if not set_mask_start:
+                            mask_start += 1
+                        mask_end += 1
+                    masked_positions = list(range(mask_start - 1, mask_end))
+
+                ner_mask = [0] * len(sentence_ids)
+                for pos in masked_positions:
+                    ner_mask[pos] = 1
+                doc_ner_mask.append(ner_mask)
+
+            if self.args.append_eod:
+                doc_ids[-1].append(Encoder.tokenizer.eod)
+                doc_ner_mask[-1].append(0)
+
+            ids[key] = doc_ids
+            ner_masks[key + '-ner'] = doc_ner_mask
+        return ids, ner_masks, len(json_line)
+

 def get_args():
     parser = argparse.ArgumentParser()
     group = parser.add_argument_group(title='input data')
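A toy walk-through of the span-to-token alignment in the encode_with_ner hunk above. id_to_str_pos_map is assumed to return each wordpiece's character offset within the detokenized sentence; the offsets and span below are made up and replace spacy and the Megatron tokenizer:

sentence = "Paris is in France"
token_pos_map = [0, 6, 9, 12]                 # offsets of ['Paris', 'is', 'in', 'France']
entity_start_char, entity_end_char = 12, 18   # the spacy span "France"

mask_start = mask_end = 0
set_mask_start = False
while mask_end < len(token_pos_map) and token_pos_map[mask_end] < entity_end_char:
    if token_pos_map[mask_start] > entity_start_char:
        set_mask_start = True
    if not set_mask_start:
        mask_start += 1
    mask_end += 1
masked_positions = list(range(mask_start - 1, mask_end))

ner_mask = [0] * len(token_pos_map)
for pos in masked_positions:
    ner_mask[pos] = 1
print(ner_mask)  # [0, 0, 0, 1] -- only 'France' is flagged for masking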
@@ -126,6 +188,8 @@ def get_args():
                        help='Number of worker processes to launch')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
+    group.add_argument('--create-ner-masks', action='store_true',
+                       help='Also create mask tensors for salient span masking')
     args = parser.parse_args()
     args.keep_empty = False

@@ -153,8 +217,11 @@ def main():
     encoder = Encoder(args)
     tokenizer = build_tokenizer(args)
     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
+    if args.create_ner_masks:
+        encoded_docs = pool.imap(encoder.encode_with_ner, fin, 25)
+    else:
+        encoded_docs = pool.imap(encoder.encode, fin, 25)
     #encoded_docs = map(encoder.encode, fin)

     level = "document"
     if args.split_sentences:

@@ -165,7 +232,10 @@ def main():
     output_bin_files = {}
     output_idx_files = {}
     builders = {}
-    for key in args.json_keys:
+    output_keys = args.json_keys.copy()
+    if args.create_ner_masks:
+        output_keys.extend([key + '-ner' for key in output_keys])
+    for key in output_keys:
         output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                       key, level)
         output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,

@@ -179,12 +249,24 @@ def main():
     total_bytes_processed = 0
     print("Time to startup:", startup_end - startup_start)

-    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+    # for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+    for i, doc_data in enumerate(encoded_docs, start=1):
+        if args.create_ner_masks:
+            doc, ner_masks, bytes_processed = doc_data
+        else:
+            doc, bytes_processed = doc_data
         total_bytes_processed += bytes_processed
         for key, sentences in doc.items():
             for sentence in sentences:
                 builders[key].add_item(torch.IntTensor(sentence))
             builders[key].end_document()
+        if args.create_ner_masks:
+            for key, sentence_masks in ner_masks.items():
+                for mask in sentence_masks:
+                    builders[key].add_item(torch.IntTensor(mask))
+                builders[key].end_document()
         if i % args.log_interval == 0:
             current = time.time()
             elapsed = current - proc_start

@@ -193,7 +275,7 @@ def main():
               f"({i/elapsed} docs/s, {mbs} MB/s).",
               file=sys.stderr)

-    for key in args.json_keys:
+    for key in output_keys:
         builders[key].finalize(output_idx_files[key])


 if __name__ == '__main__':