Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
371d2ea9
Commit
371d2ea9
authored
Mar 26, 2020
by
Neel Kant
Browse files
Complete definition of ICTBertModel
parent
9873a8da
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
1 deletion
+33
-1
megatron/model/bert_model.py
megatron/model/bert_model.py
+33
-1
No files found.
megatron/model/bert_model.py
View file @
371d2ea9
...
...
@@ -284,4 +284,36 @@ class ICTBertModel(MegatronModule):
attention_softmax_in_fp32
=
attention_softmax_in_fp32
)
self
.
question_model
=
BertModel
(
**
bert_args
)
self
.
evidence_model
=
BertModel
(
**
bert_args
)
self
.
_question_key
=
'question_model'
self
.
context_model
=
BertModel
(
**
bert_args
)
self
.
_context_key
=
'context_model'
def forward(self, input_tokens, input_attention_mask, input_types,
            context_tokens, context_attention_mask, context_types):
    """Score every question in the batch against every context in the batch.

    The ``input_*`` arguments are passed to ``self.question_model`` and the
    ``context_*`` arguments to ``self.context_model``. Each sub-model is
    expected to return a pair whose first element is used as its ICT logits;
    the second element is discarded.

    Returns:
        The matrix product of the question logits with the context logits
        transposed over dims 0 and 1.
    """
    # Call the module instances rather than `.forward(...)` directly:
    # `nn.Module.__call__` runs registered forward (pre-)hooks, which a
    # direct `.forward` call silently skips.
    question_ict_logits, _ = self.question_model(
        input_tokens, input_attention_mask, input_types)
    context_ict_logits, _ = self.context_model(
        context_tokens, context_attention_mask, context_types)

    # [batch x h] * [h x batch]
    retrieval_scores = question_ict_logits.matmul(
        torch.transpose(context_ict_logits, 0, 1))

    return retrieval_scores
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                   keep_vars=False):
    """Collect both sub-model checkpoints into a single dictionary.

    ``destination``, ``prefix`` and ``keep_vars`` are forwarded unchanged
    (positionally) to the ``state_dict_for_save_checkpoint`` method of the
    question model and then of the context model.

    Returns:
        A new dict mapping ``self._question_key`` to the question model's
        result and ``self._context_key`` to the context model's result.
    """
    # A dict literal evaluates its entries in order, so the question model
    # is still queried before the context model.
    return {
        self._question_key: self.question_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars),
        self._context_key: self.context_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars),
    }
def load_state_dict(self, state_dict, strict=True):
    """Restore both sub-models from a combined checkpoint dictionary.

    ``state_dict`` is indexed with ``self._question_key`` and
    ``self._context_key``; each entry is handed to the matching sub-model's
    ``load_state_dict`` together with ``strict``. A missing key raises
    ``KeyError`` from the dictionary lookup.
    """
    question_weights = state_dict[self._question_key]
    self.question_model.load_state_dict(question_weights, strict=strict)

    context_weights = state_dict[self._context_key]
    self.context_model.load_state_dict(context_weights, strict=strict)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment