OpenDAS / Megatron-LM · Commits

Commit 447c1171
Authored Feb 15, 2021 by Mostofa Patwary

    addressed the comments given by Mohammad

Parent: 22a3d81a
Showing 3 changed files with 30 additions and 39 deletions (+30, -39):

    megatron/arguments.py               +3   -7
    megatron/model/biencoder_model.py   +23  -29
    pretrain_ict.py                     +4   -3
megatron/arguments.py

@@ -646,16 +646,12 @@ def _add_biencoder_args(parser):
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
-    group.add_argument('--projection-dim', type=int, default=0,
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                        help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
-    group.add_argument('--shared-query-context-model', action='store_true',
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')
-    group.add_argument('--pool-type', type=str, default='cls-token',
-                       choices=['avg', 'cls-token', 'max'],
-                       help='different options are: avg | cls-token | max, '
-                       'default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,

@@ -674,7 +670,7 @@ def _add_biencoder_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', type=int,
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
     group.add_argument('--retriever-score-scaling', action='store_true',
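For reference, a minimal, self-contained sketch (not part of this commit) of how the renamed flags behave under plain argparse. The flag names, defaults, and help strings mirror the diff above; the parsed values and the bare parser are only illustrative, not Megatron's full argument setup:

    import argparse

    # Hedged sketch: a standalone parser with only the renamed biencoder/retriever flags.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='biencoder')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                       ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                       default=[], help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")

    # Dashes become underscores on the parsed namespace, which is why the model
    # code below reads args.biencoder_projection_dim and
    # args.retriever_report_topk_accuracies.
    args = parser.parse_args(['--biencoder-projection-dim', '128',
                              '--retriever-report-topk-accuracies', '1', '5', '20'])
    print(args.biencoder_projection_dim, args.retriever_report_topk_accuracies)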
megatron/model/biencoder_model.py

@@ -17,7 +17,7 @@ from .module import MegatronModule

 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False):
     """Build the model."""

     args = get_args()

@@ -31,10 +31,11 @@ def biencoder_model_provider(only_query_model=False,
     # the LM we initialize with has 2 tokentypes
     model = BiEncoderModel(
         num_tokentypes=2,
-        parallel_output=True,
+        parallel_output=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
-        shared_query_context_model=shared_query_context_model)
+        biencoder_shared_query_context_model=\
+            biencoder_shared_query_context_model)

     return model

@@ -47,7 +48,7 @@ class BiEncoderModel(MegatronModule):
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False):
         super(BiEncoderModel, self).__init__()
         args = get_args()

@@ -55,13 +56,14 @@ class BiEncoderModel(MegatronModule):
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

-        self.shared_query_context_model = shared_query_context_model
+        self.biencoder_shared_query_context_model = \
+            biencoder_shared_query_context_model
         assert not (only_context_model and only_query_model)
         self.use_context_model = not only_query_model
         self.use_query_model = not only_context_model
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model = PretrainedBertModel(**bert_kwargs)
             self._model_key = 'shared_model'
             self.query_model, self.context_model = self.model, self.model

@@ -109,7 +111,7 @@ class BiEncoderModel(MegatronModule):
                                        prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             state_dict_[self._model_key] = \
                 self.model.state_dict_for_save_checkpoint(destination, prefix,

@@ -129,7 +131,7 @@ class BiEncoderModel(MegatronModule):
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             print_rank_0("Loading shared query-context model")
             self.model.load_state_dict(state_dict[self._model_key], \
                 strict=strict)

@@ -188,14 +190,14 @@ class BiEncoderModel(MegatronModule):
         # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']

-        if self.shared_query_context_model:
+        if self.biencoder_shared_query_context_model:
             self.model.language_model.load_state_dict(model_dict)
             fix_query_key_value_ordering(self.model, checkpoint_version)
         else:
             if self.use_query_model:
                 self.query_model.language_model.load_state_dict(model_dict)
                 # give each model the same ict_head to begin with as well
-                if self.projection_dim > 0:
+                if self.biencoder_projection_dim > 0:
                     query_proj_state_dict = \
                         self.state_dict_for_save_checkpoint()\
                         [self._query_key]['projection_enc']

@@ -203,7 +205,8 @@ class BiEncoderModel(MegatronModule):
             if self.use_context_model:
                 self.context_model.language_model.load_state_dict(model_dict)
-                if self.query_model is not None and self.projection_dim > 0:
+                if self.query_model is not None and \
+                    self.biencoder_projection_dim > 0:
                     self.context_model.projection_enc.load_state_dict\
                         (query_proj_state_dict)
                     fix_query_key_value_ordering(self.context_model, checkpoint_version)

@@ -220,8 +223,7 @@ class PretrainedBertModel(MegatronModule):
         args = get_args()
         tokenizer = get_tokenizer()
         self.pad_id = tokenizer.pad
-        self.pool_type = args.pool_type
-        self.projection_dim = args.projection_dim
+        self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(

@@ -234,9 +236,9 @@ class PretrainedBertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)

-        if args.projection_dim > 0:
+        if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
-                                                   args.projection_dim,
+                                                   args.biencoder_projection_dim,
                                                    init_method)
             self._projection_enc_key = 'projection_enc'

@@ -253,22 +255,14 @@ class PretrainedBertModel(MegatronModule):
         # This mask will be used in average-pooling and max-pooling
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

-        # Taking the representation of the [CLS] token of BERT
-        if self.pool_type == "cls-token":
-            pooled_output = lm_output[:, 0, :]
-        elif self.pool_type == "avg": # Average Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, 0)
-            pooled_output = pooled_output.sum(1) / (pool_mask.size(1) \
-                - pool_mask.float().sum(1))
-        elif self.pool_type == "max": # Max-Pooling
-            pooled_output = lm_output.masked_fill(pool_mask, -1000)
-            pooled_output = torch.max(pooled_output, 1)[0]
+        # Taking the representation of the [CLS] token of BERT
+        pooled_output = lm_output[:, 0, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)

         # Output.
-        if self.projection_dim:
+        if self.biencoder_projection_dim:
             pooled_output = self.projection_enc(pooled_output)

         return pooled_output

@@ -283,7 +277,7 @@ class PretrainedBertModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
             destination, prefix, keep_vars)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             state_dict_[self._projection_enc_key] = \
                 self.projection_enc.state_dict(destination, prefix, keep_vars)

@@ -295,7 +289,7 @@ class PretrainedBertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)

-        if self.projection_dim > 0:
+        if self.biencoder_projection_dim > 0:
             print_rank_0("loading projection head weights")
             self.projection_enc.load_state_dict(
                 state_dict[self._projection_enc_key], strict=strict)
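As a side note, after this change the forward path always pools the [CLS] token and then optionally applies the projection head (the avg/max pooling branches and --pool-type are gone). A standalone sketch of just that pooling logic, with illustrative sizes (hidden_size=768, projection dim 128) standing in for Megatron's configured values:

    import torch
    import torch.nn as nn

    # Hedged sketch: mirrors the post-commit pooling path, not the full PretrainedBertModel.
    hidden_size, projection_dim = 768, 128
    projection_enc = nn.Linear(hidden_size, projection_dim)

    def embed(lm_output: torch.Tensor) -> torch.Tensor:
        # lm_output: [batch, seq_len, hidden_size] from the BERT encoder
        pooled_output = lm_output[:, 0, :]        # representation of the [CLS] token
        pooled_output = pooled_output.to(lm_output.dtype)
        if projection_dim:                        # stands in for biencoder_projection_dim > 0
            pooled_output = projection_enc(pooled_output)
        return pooled_output

    print(embed(torch.randn(4, 32, hidden_size)).shape)  # torch.Size([4, 128])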
pretrain_ict.py

@@ -36,7 +36,8 @@ def pretrain_ict_model_provider():
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
-                shared_query_context_model=args.shared_query_context_model)
+                biencoder_shared_query_context_model=\
+                args.biencoder_shared_query_context_model)

     return model

 def get_group_world_size_rank():

@@ -120,7 +121,7 @@ def forward_step(data_iterator, model, input_tensor):
             return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
                 for i in range(global_batch_size)]) / global_batch_size])

-        topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+        topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

         labels = torch.arange(global_batch_size).long().cuda()
         loss = F.nll_loss(softmax_scores, labels, reduction='mean')

@@ -131,7 +132,7 @@ def forward_step(data_iterator, model, input_tensor):
         # create stats_dict with retrieval loss and all specified top-k accuracies
         topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
-            zip(args.report_topk_accuracies, reduced_losses[1:])}
+            zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
         stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
         return loss, stats_dict
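For intuition, the reported metrics are in-batch retrieval accuracies: query i's positive context is context i, so a top-k hit means index i appears among row i's top-k scored contexts. A CPU-only sketch of that computation (illustrative score tensor and k values; the real forward_step uses torch.cuda.FloatTensor and the parsed args.retriever_report_topk_accuracies):

    import torch

    # Hedged sketch of the top-k accuracy that forward_step reports.
    def topk_accuracy(sorted_indices: torch.Tensor, k: int) -> float:
        global_batch_size = sorted_indices.size(0)
        # Query i is correct if its own index i is among its top-k retrieved contexts.
        hits = sum(int(i in sorted_indices[i, :k]) for i in range(global_batch_size))
        return hits / global_batch_size

    scores = torch.randn(8, 8)                          # query-context similarity scores
    sorted_indices = torch.argsort(scores, dim=1, descending=True)
    report_topk_accuracies = [1, 5]                     # e.g. --retriever-report-topk-accuracies 1 5
    print({'top{}_acc'.format(k): topk_accuracy(sorted_indices, k) * 100
           for k in report_topk_accuracies})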