OpenDAS / Megatron-LM

Commit 6c0a5bd8
Authored Apr 28, 2020 by Neel Kant
Parent: 1eccfc94

Update and test misc functionality

Showing 8 changed files with 86 additions and 38 deletions (+86 -38)

Files changed:
  hashed_index.py                    +22 -14
  megatron/data/ict_dataset.py        +5  -5
  megatron/data/realm_dataset.py      +1  -1
  megatron/model/bert_model.py       +36  -7
  megatron/model/language_model.py    +7  -4
  megatron/training.py                +4  -2
  pretrain_bert_ict.py                +4  -2
  pretrain_realm.py                   +7  -3

hashed_index.py

@@ -29,7 +29,8 @@ class HashedIndex(object):
         np.random.seed(seed)
         self.block_data = defaultdict(list)
         self.hash_data = defaultdict(list)
-        self.hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
+        hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
+        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)

     def state(self):
         state = {

@@ -47,7 +48,7 @@ class HashedIndex(object):
     def hash_embeds(self, embeds, block_data=None):
         """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
         embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
         embed_hashes = detach(torch.argmax(embed_scores, axis=1))
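
The change above is a dtype fix in an otherwise standard random-hyperplane hashing step: embeddings are scored against num_buckets / 2 normalized random directions, the scores and their negations are concatenated, and the argmax assigns one of num_buckets signed buckets. A minimal NumPy sketch of the same idea, with illustrative names and shapes that are not taken from the repository:

import numpy as np

def make_hash_matrix(embed_size, num_buckets, seed=0):
    # num_buckets / 2 random directions; columns are L2-normalized so the
    # argmax over dot products is not biased toward longer columns.
    rng = np.random.RandomState(seed)
    hash_matrix = rng.rand(embed_size, num_buckets // 2)
    return hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)

def hash_embeds(embeds, hash_matrix):
    # embeds: [batch, embed_size]; score against each direction and its negation.
    scores_pos = embeds @ hash_matrix                            # [batch, num_buckets // 2]
    scores = np.concatenate((scores_pos, -scores_pos), axis=1)   # [batch, num_buckets]
    return np.argmax(scores, axis=1)                             # one bucket id per embedding

if __name__ == "__main__":
    print(hash_embeds(np.random.rand(4, 128), make_hash_matrix(128, 4096)))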

@@ -62,7 +63,7 @@ class HashedIndex(object):
         for idx, embed in zip(block_indices, block_embeds):
             if not allow_overwrite and int(idx) in self.block_data:
                 raise ValueError("Attempted to overwrite a read-only HashedIndex")
-            self.block_data[int(idx)] = embed
+            self.block_data[int(idx)] = np.float16(embed)

     def save_shard(self, rank):
         dir_name = 'block_hash_data'

@@ -92,7 +93,8 @@ class HashedIndex(object):
             for bucket, items in data['hash_data'].items():
                 self.hash_data[bucket].extend(items)

-        with open('block_hash_data.pkl', 'wb') as final_file:
+        args = get_args()
+        with open(args.hash_data_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)

         shutil.rmtree(dir_name, ignore_errors=True)
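
The hard-coded 'block_hash_data.pkl' output becomes the configurable args.hash_data_path: each rank writes a shard, the shards are merged into one pickle, and the shard directory is removed. A self-contained sketch of that consolidation pattern, where the file layout and the {'block_data', 'hash_data'} state format are assumptions for illustration:

import os
import pickle
import shutil
from collections import defaultdict

def consolidate_shards(shard_dir, output_path):
    """Merge per-rank {block_data, hash_data} shards into a single index file."""
    block_data, hash_data = {}, defaultdict(list)
    for fname in sorted(os.listdir(shard_dir)):
        with open(os.path.join(shard_dir, fname), 'rb') as f:
            data = pickle.load(f)
        block_data.update(data['block_data'])
        for bucket, items in data['hash_data'].items():
            hash_data[bucket].extend(items)

    with open(output_path, 'wb') as final_file:
        pickle.dump({'block_data': block_data, 'hash_data': dict(hash_data)}, final_file)

    # The shard directory is no longer needed once the merged file exists.
    shutil.rmtree(shard_dir, ignore_errors=True)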

@@ -119,7 +121,7 @@ def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint()
+    model = load_ict_checkpoint(only_block_model=True)
     model.eval()
     dataset = get_ict_dataset()
     hashed_index = HashedIndex.load_from_file(args.hash_data_path)

@@ -158,11 +160,11 @@ def main():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint()
+    model = load_ict_checkpoint(only_block_model=True, no_grad=True)
     model.eval()
     dataset = get_ict_dataset()
     data_iter = iter(get_dataloader(dataset))
-    hashed_index = HashedIndex(embed_size=128, num_buckets=2048)
+    hashed_index = HashedIndex(embed_size=128, num_buckets=4096)
     i = 0
     while True:

@@ -172,10 +174,8 @@ def main():
         except:
             break

-        actual_model = model.module.module
         block_indices = detach(block_indices)
-        block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
+        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)

         hashed_index.hash_embeds(block_logits, block_indices)
         hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))

@@ -193,9 +193,9 @@ def main():
         hashed_index.clear()


-def load_ict_checkpoint():
+def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
     args = get_args()
-    model = get_model(model_provider)
+    model = get_model(lambda: model_provider(only_query_model, only_block_model))
     if isinstance(model, torchDDP):
         model = model.module

@@ -210,7 +210,15 @@ def load_ict_checkpoint():
               torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    model.load_state_dict(state_dict['model'])
+    if only_query_model:
+        state_dict['model'].pop('context_model')
+    if only_block_model:
+        state_dict['model'].pop('question_model')
+    if no_grad:
+        with torch.no_grad():
+            model.load_state_dict(state_dict['model'])
+    else:
+        model.load_state_dict(state_dict['model'])

     torch.distributed.barrier()
     if mpu.get_data_parallel_rank() == 0:
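
load_ict_checkpoint now drops the unused tower's weights from the checkpoint before loading, so a block-only or query-only model can be restored from a full two-tower ICT checkpoint. A standalone sketch of that pattern with plain PyTorch modules; the class, key names, and checkpoint layout below are illustrative stand-ins, not Megatron's:

import torch
import torch.nn as nn

class TwoTowerModel(nn.Module):
    """Toy stand-in for ICTBertModel: a query tower and a block (context) tower."""
    def __init__(self, only_query_model=False, only_block_model=False):
        super().__init__()
        self.query_model = None if only_block_model else nn.Linear(16, 8)
        self.block_model = None if only_query_model else nn.Linear(16, 8)

def save_two_tower(model, path):
    # The checkpoint keeps one nested state dict per tower, mirroring the
    # 'question_model' / 'context_model' keys that the commit pops.
    torch.save({'question_model': model.query_model.state_dict(),
                'context_model': model.block_model.state_dict()}, path)

def load_two_tower(path, only_query_model=False, only_block_model=False, no_grad=False):
    model = TwoTowerModel(only_query_model, only_block_model)
    state_dict = torch.load(path, map_location='cpu')
    # Drop the tower that was not built so nothing tries to load into it.
    if only_query_model:
        state_dict.pop('context_model')
    if only_block_model:
        state_dict.pop('question_model')

    def _load():
        if 'question_model' in state_dict:
            model.query_model.load_state_dict(state_dict['question_model'])
        if 'context_model' in state_dict:
            model.block_model.load_state_dict(state_dict['context_model'])

    if no_grad:
        with torch.no_grad():
            _load()
    else:
        _load()
    return model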

@@ -261,4 +269,4 @@ def get_dataloader(dataset):

 if __name__ == "__main__":
-    test_retriever()
+    main()

megatron/data/ict_dataset.py

@@ -131,8 +131,8 @@ class InverseClozeDataset(Dataset):
                   'the indices on rank 0 ...'.format(indexmap_filename))

             # Make sure the types match the helpers input types.
-            assert self.context_dataset.doc_idx.dtype == np.int64
-            assert self.context_dataset.sizes.dtype == np.int32
+            assert self.block_dataset.doc_idx.dtype == np.int64
+            assert self.block_dataset.sizes.dtype == np.int32

             # Build samples mapping
             verbose = torch.distributed.get_rank() == 0

@@ -140,9 +140,9 @@ class InverseClozeDataset(Dataset):
             print_rank_0(' > building samples index mapping for {} ...'.format(
                 self.name))
             samples_mapping = helpers.build_blocks_mapping(
-                self.context_dataset.doc_idx,
-                self.context_dataset.sizes,
-                self.titles_dataset.sizes,
+                self.block_dataset.doc_idx,
+                self.block_dataset.sizes,
+                self.title_dataset.sizes,
                 num_epochs,
                 max_num_samples,
                 self.max_seq_length - 3,  # account for added tokens

megatron/data/realm_dataset.py

@@ -47,7 +47,7 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                                        masked_labels, pad_id, max_seq_length)

     # REALM true sequence length is twice as long but none of that is to be predicted with LM
-    loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)
+    # loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)

     train_sample = {
         'tokens': tokens_np,

megatron/model/bert_model.py

@@ -126,12 +126,18 @@ class BertModel(MegatronModule):
         add_pooler = self.add_binary_head or self.add_ict_head
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
+        max_pos_embeds = None
+        if not add_binary_head and ict_head_size is None:
+            max_pos_embeds = 2 * args.seq_length
+
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=add_pooler,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            max_pos_embeds=max_pos_embeds)

         if not self.add_ict_head:
             self.lm_head = BertLMHead(
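
The doubled position-embedding budget appears to be there for REALM-style inputs, where a query and a retrieved block are concatenated into one sequence, so token positions can run up to twice the pretraining seq_length. A hedged sketch of sizing such an embedding table; the function and argument names are illustrative:

import torch.nn as nn

def build_position_embeddings(hidden_size, seq_length, concat_query_and_block=False):
    # When query and block are concatenated (as in REALM), positions can reach
    # 2 * seq_length, so the table has to be sized accordingly.
    max_positions = 2 * seq_length if concat_query_and_block else seq_length
    return nn.Embedding(max_positions, hidden_size)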

@@ -218,6 +224,8 @@ class BertModel(MegatronModule):

 class REALMBertModel(MegatronModule):
+    # TODO: load BertModel checkpoint
     def __init__(self, retriever):
         super(REALMBertModel, self).__init__()
         bert_args = dict(

@@ -241,10 +249,11 @@ class REALMBertModel(MegatronModule):
         top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)

         # [batch_size x 5 x embed_size]
-        fresh_block_logits = self.retriever.ict_model.module.module.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
+        fresh_block_logits = self.retriever.ict_model(None, None, top5_block_tokens, top5_block_attention_mask, only_block=True).reshape(batch_size, 5, -1)
+        # fresh_block_logits.register_hook(lambda x: print("fresh block: ", x.shape, flush=True))

         # [batch_size x embed_size x 1]
-        query_logits = self.retriever.ict_model.module.module.embed_query(tokens, attention_mask).unsqueeze(2)
+        query_logits = self.retriever.ict_model(tokens, attention_mask, None, None, only_query=True).unsqueeze(2)

         # [batch_size x 5]

@@ -282,6 +291,7 @@ class REALMRetriever(MegatronModule):
         self.ict_model = ict_model
         self.ict_dataset = ict_dataset
         self.hashed_index = hashed_index
+        self.top_k = top_k

     def retrieve_evidence_blocks_text(self, query_text):
         """Get the top k evidence blocks for query_text in text form"""

@@ -300,16 +310,25 @@ class REALMRetriever(MegatronModule):
             print('\n    > Block {}: {}'.format(i, block_text))

     def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
-        query_embeds = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
+        """Embed blocks to be used in a forward pass"""
+        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
         query_hashes = self.hashed_index.hash_embeds(query_embeds)
         block_buckets = [self.hashed_index.get_block_bucket(hash) for hash in query_hashes]
-        block_embeds = [torch.cuda.HalfTensor(np.array([self.hashed_index.get_block_embed(arr[3])
+        for j, bucket in enumerate(block_buckets):
+            if len(bucket) < 5:
+                for i in range(len(block_buckets)):
+                    if len(block_buckets[i]) > 5:
+                        block_buckets[j] = block_buckets[i].copy()
+
+        # [batch_size x max_bucket_population x embed_size]
+        block_embeds = [torch.cuda.FloatTensor(np.array([self.hashed_index.get_block_embed(arr[3])
                                                          for arr in bucket]))
                         for bucket in block_buckets]

         all_top5_tokens, all_top5_pad_masks = [], []
         for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
-            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor, 0, 1))
+            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor.reshape(-1, query_embed.size()[0]), 0, 1))
+            print(retrieval_scores.shape, flush=True)
             top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
             top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]
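
Retrieval itself is a dot product between the query embedding and every candidate block embedding in the hashed bucket, followed by torch.topk. A stripped-down sketch of that scoring step; names and shapes are illustrative, and it runs on CPU tensors rather than torch.cuda ones:

import torch

def top_k_blocks(query_embed, block_embeds, k=5):
    """Score one query against a bucket of candidate block embeddings.

    query_embed:  [embed_size]
    block_embeds: [num_candidates, embed_size]
    Returns the (values, indices) of the k best-matching candidates.
    """
    # Inner-product relevance score for every candidate block.
    retrieval_scores = query_embed.matmul(block_embeds.transpose(0, 1))  # [num_candidates]
    k = min(k, retrieval_scores.numel())
    return torch.topk(retrieval_scores, k=k, sorted=True)

if __name__ == "__main__":
    vals, idx = top_k_blocks(torch.randn(128), torch.randn(32, 128))
    print(idx.tolist())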

@@ -354,8 +373,16 @@ class ICTBertModel(MegatronModule):
             self.block_model = BertModel(**bert_args)
             self._block_key = 'context_model'

-    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
+    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask,
+                only_query=False, only_block=False):
         """Run a forward pass for each of the models and compute the similarity scores."""
+        if only_query:
+            return self.embed_query(query_tokens, query_attention_mask)
+
+        if only_block:
+            return self.embed_block(block_tokens, block_attention_mask)
+
         query_logits = self.embed_query(query_tokens, query_attention_mask)
         block_logits = self.embed_block(block_tokens, block_attention_mask)
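
Routing everything through forward() with only_query / only_block flags lets callers stop reaching into model.module.module for embed_query / embed_block, so DistributedDataParallel wrappers stay in the call path. A minimal sketch of the dispatch pattern; the towers here are simple embedding bags purely for illustration, and attention masks are omitted:

import torch
import torch.nn as nn

class TwoTowerEncoder(nn.Module):
    """Query tower and block tower sharing a single forward() entry point."""
    def __init__(self, vocab_size=1000, embed_size=128):
        super().__init__()
        self.query_model = nn.EmbeddingBag(vocab_size, embed_size)
        self.block_model = nn.EmbeddingBag(vocab_size, embed_size)

    def embed_query(self, query_tokens):
        return self.query_model(query_tokens)

    def embed_block(self, block_tokens):
        return self.block_model(block_tokens)

    def forward(self, query_tokens, block_tokens, only_query=False, only_block=False):
        # Early exits let a wrapped (e.g. DDP) model embed just one side.
        if only_query:
            return self.embed_query(query_tokens)
        if only_block:
            return self.embed_block(block_tokens)
        query_logits = self.embed_query(query_tokens)   # [batch, embed_size]
        block_logits = self.embed_block(block_tokens)   # [batch, embed_size]
        # Retrieval scores: every query against every block in the batch.
        return query_logits.matmul(block_logits.transpose(0, 1))

if __name__ == "__main__":
    model = TwoTowerEncoder()
    queries = torch.randint(0, 1000, (4, 16))
    blocks = torch.randint(0, 1000, (4, 64))
    print(model(queries, blocks).shape)                  # torch.Size([4, 4])
    print(model(queries, None, only_query=True).shape)   # torch.Size([4, 128])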

@@ -399,9 +426,11 @@ class ICTBertModel(MegatronModule):
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
         if self.use_query_model:
+            print("Loading ICT query model", flush=True)
             self.query_model.load_state_dict(
                 state_dict[self._query_key], strict=strict)

         if self.use_block_model:
+            print("Loading ICT block model", flush=True)
             self.block_model.load_state_dict(
                 state_dict[self._block_key], strict=strict)

megatron/model/language_model.py

@@ -45,7 +45,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method, scaled_init_method):
+                       init_method, scaled_init_method, max_pos_embeds=None):
     """Build language model and return along with the key to save."""

     # Language model.

@@ -55,7 +55,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method=init_method,
         output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
-        add_pooler=add_pooler)
+        add_pooler=add_pooler,
+        max_pos_embeds=max_pos_embeds)
     # key used for checkpoints.
     language_model_key = 'language_model'

@@ -266,7 +267,8 @@ class TransformerLanguageModel(MegatronModule):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
-                 add_pooler=False):
+                 add_pooler=False,
+                 max_pos_embeds=None):
         super(TransformerLanguageModel, self).__init__()
         args = get_args()

@@ -275,10 +277,11 @@ class TransformerLanguageModel(MegatronModule):
         self.init_method = init_method
         self.add_pooler = add_pooler
+        max_pos_embeds = args.max_position_embeddings if max_pos_embeds is None else max_pos_embeds

         # Embeddings
         self.embedding = Embedding(self.hidden_size,
                                    args.padded_vocab_size,
-                                   args.max_position_embeddings,
+                                   max_pos_embeds,
                                    args.hidden_dropout,
                                    self.init_method,
                                    self.num_tokentypes)

megatron/training.py

@@ -225,7 +225,7 @@ def backward_step(optimizer, model, loss):
     """Backward step."""
     args = get_args()
     timers = get_timers()
-    print("start backward", flush=True)
+    torch.cuda.synchronize()

     # Backward pass.
     optimizer.zero_grad()

@@ -250,6 +250,7 @@ def backward_step(optimizer, model, loss):
     else:
         optimizer.clip_master_grads(args.clip_grad)

+ran_backward_once = False

 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):

@@ -262,11 +263,12 @@ def train_step(forward_step_func, data_iterator,
     loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()

-    # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss)
     timers('backward').stop()
+    # Calculate gradients, reduce across processes, and clip.

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
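
The training step keeps its forward, backward, and optimizer phases wrapped in named timers, with backward_step now synchronizing CUDA before the backward pass. A compact sketch of that structure; the Timers class and backward_step below are simplified stand-ins, not Megatron's get_timers() API:

import time

class SimpleTimer:
    def __init__(self):
        self._start = None
        self.elapsed = 0.0
    def start(self):
        self._start = time.time()
    def stop(self):
        self.elapsed += time.time() - self._start

class Timers:
    """Named timers in the spirit of megatron's get_timers()."""
    def __init__(self):
        self._timers = {}
    def __call__(self, name):
        return self._timers.setdefault(name, SimpleTimer())

def backward_step(optimizer, model, loss):
    # Zero grads and run the backward pass; gradient clipping is omitted here.
    optimizer.zero_grad()
    loss.backward()

def train_step(forward_step_func, data_iterator, model, optimizer, timers):
    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()
    return loss_reduced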

pretrain_bert_ict.py

@@ -29,7 +29,7 @@ from megatron.utils import reduce_losses
 num_batches = 0

-def model_provider():
+def model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()

     print_rank_0('building BERT models ...')

@@ -37,7 +37,9 @@ def model_provider():
     model = ICTBertModel(
         ict_head_size=128,
         num_tokentypes=2,
-        parallel_output=True)
+        parallel_output=True,
+        only_query_model=only_query_model,
+        only_block_model=only_block_model)

     return model

pretrain_realm.py

@@ -38,9 +38,10 @@ def model_provider():
     ict_model = load_ict_checkpoint()
     ict_dataset = get_ict_dataset()
-    hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
+    hashed_index = HashedIndex.load_from_file(args.hash_data_path)
     retriever = REALMRetriever(ict_model, ict_dataset, hashed_index)

+    # TODO: REALMBertModel should accept a path to a pretrained bert-base
     model = REALMBertModel(retriever)

     return model

@@ -74,7 +75,6 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
-    labels = torch.cat((labels, labels), axis=-1)
     timers('batch generator').stop()

     # Forward model.

@@ -83,13 +83,17 @@ def forward_step(data_iterator, model):
     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
-    lm_logits = torch.sum(lm_logits * block_probs, dim=1)
+    #block_probs.register_hook(lambda x: print("block_probs: ", x.shape, flush=True))
+    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

     lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                 labels.contiguous())
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
     reduced_loss = reduce_losses([lm_loss])
+    torch.cuda.synchronize()
     print(reduced_loss, flush=True)

     return lm_loss, {'lm_loss': reduced_loss[0]}
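
The forward step implements REALM's marginalization over retrieved evidence, P(y|x) = sum_z(P(y|z, x) * P(z|x)): per-block LM logits are weighted by the retrieval probabilities and summed over the block dimension before the cross-entropy loss. A small PyTorch sketch of that reduction; the shapes are illustrative, and the real code additionally truncates to the label length and uses Megatron's vocab-parallel cross entropy:

import torch
import torch.nn.functional as F

def marginalize_over_blocks(lm_logits, block_probs, labels):
    """Combine per-block LM predictions, weighted by retrieval probability.

    lm_logits:   [batch, num_blocks, seq_length, vocab]  -- scores for P(y | z, x)
    block_probs: [batch, num_blocks]                     -- P(z | x)
    labels:      [batch, seq_length]
    """
    weights = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    marginal_logits = torch.sum(lm_logits * weights, dim=1)   # [batch, seq_length, vocab]
    loss = F.cross_entropy(marginal_logits.reshape(-1, marginal_logits.size(-1)),
                           labels.reshape(-1))
    return marginal_logits, loss

if __name__ == "__main__":
    batch, blocks, seq, vocab = 2, 5, 8, 100
    lm_logits = torch.randn(batch, blocks, seq, vocab)
    block_probs = torch.softmax(torch.randn(batch, blocks), dim=-1)
    labels = torch.randint(0, vocab, (batch, seq))
    _, loss = marginalize_over_blocks(lm_logits, block_probs, labels)
    print(loss.item())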