Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
41436d3d
Unverified
Commit
41436d3d
authored
Sep 30, 2021
by
Patrick von Platen
Committed by
GitHub
Sep 30, 2021
Browse files
[DPR] Correct init (#13796)
* update * add to docs and init * make fix-copies
parent
44eb8bde
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
32 deletions
+61
-32
docs/source/model_doc/dpr.rst
docs/source/model_doc/dpr.rst
+7
-0
src/transformers/__init__.py
src/transformers/__init__.py
+2
-0
src/transformers/models/dpr/__init__.py
src/transformers/models/dpr/__init__.py
+2
-0
src/transformers/models/dpr/modeling_dpr.py
src/transformers/models/dpr/modeling_dpr.py
+27
-32
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+9
-0
tests/test_modeling_dpr.py
tests/test_modeling_dpr.py
+14
-0
No files found.
docs/source/model_doc/dpr.rst
View file @
41436d3d
...
@@ -41,6 +41,13 @@ DPRConfig
...
@@ -41,6 +41,13 @@ DPRConfig
:members:
:members:
DPRPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRPreTrainedModel
:members:
DPRContextEncoderTokenizer
DPRContextEncoderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
...
src/transformers/__init__.py
View file @
41436d3d
...
@@ -773,6 +773,7 @@ if is_torch_available():
...
@@ -773,6 +773,7 @@ if is_torch_available():
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DPRContextEncoder"
,
"DPRContextEncoder"
,
"DPRPretrainedContextEncoder"
,
"DPRPretrainedContextEncoder"
,
"DPRPreTrainedModel"
,
"DPRPretrainedQuestionEncoder"
,
"DPRPretrainedQuestionEncoder"
,
"DPRPretrainedReader"
,
"DPRPretrainedReader"
,
"DPRQuestionEncoder"
,
"DPRQuestionEncoder"
,
...
@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
...
@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST
,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST
,
DPRContextEncoder
,
DPRContextEncoder
,
DPRPretrainedContextEncoder
,
DPRPretrainedContextEncoder
,
DPRPreTrainedModel
,
DPRPretrainedQuestionEncoder
,
DPRPretrainedQuestionEncoder
,
DPRPretrainedReader
,
DPRPretrainedReader
,
DPRQuestionEncoder
,
DPRQuestionEncoder
,
...
...
src/transformers/models/dpr/__init__.py
View file @
41436d3d
...
@@ -46,6 +46,7 @@ if is_torch_available():
...
@@ -46,6 +46,7 @@ if is_torch_available():
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DPRContextEncoder"
,
"DPRContextEncoder"
,
"DPRPretrainedContextEncoder"
,
"DPRPretrainedContextEncoder"
,
"DPRPreTrainedModel"
,
"DPRPretrainedQuestionEncoder"
,
"DPRPretrainedQuestionEncoder"
,
"DPRPretrainedReader"
,
"DPRPretrainedReader"
,
"DPRQuestionEncoder"
,
"DPRQuestionEncoder"
,
...
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
...
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST
,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST
,
DPRContextEncoder
,
DPRContextEncoder
,
DPRPretrainedContextEncoder
,
DPRPretrainedContextEncoder
,
DPRPreTrainedModel
,
DPRPretrainedQuestionEncoder
,
DPRPretrainedQuestionEncoder
,
DPRPretrainedReader
,
DPRPretrainedReader
,
DPRQuestionEncoder
,
DPRQuestionEncoder
,
...
...
src/transformers/models/dpr/modeling_dpr.py
View file @
41436d3d
...
@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
...
@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
attentions
:
Optional
[
Tuple
[
torch
.
FloatTensor
]]
=
None
attentions
:
Optional
[
Tuple
[
torch
.
FloatTensor
]]
=
None
class
DPREncoder
(
PreTrainedModel
):
class
DPRPreTrainedModel
(
PreTrainedModel
):
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
if
isinstance
(
module
,
nn
.
Linear
):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
)
if
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
elif
isinstance
(
module
,
nn
.
Embedding
):
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
)
if
module
.
padding_idx
is
not
None
:
module
.
weight
.
data
[
module
.
padding_idx
].
zero_
()
elif
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
BertEncoder
):
module
.
gradient_checkpointing
=
value
class
DPREncoder
(
DPRPreTrainedModel
):
base_model_prefix
=
"bert_model"
base_model_prefix
=
"bert_model"
...
@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
...
@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
return
self
.
encode_proj
.
out_features
return
self
.
encode_proj
.
out_features
return
self
.
bert_model
.
config
.
hidden_size
return
self
.
bert_model
.
config
.
hidden_size
def
init_weights
(
self
):
self
.
bert_model
.
init_weights
()
if
self
.
projection_dim
>
0
:
self
.
encode_proj
.
apply
(
self
.
bert_model
.
_init_weights
)
class
DPRSpanPredictor
(
PreTrainedModel
):
class
DPRSpanPredictor
(
DPR
PreTrainedModel
):
base_model_prefix
=
"encoder"
base_model_prefix
=
"encoder"
...
@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
...
@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
attentions
=
outputs
.
attentions
,
attentions
=
outputs
.
attentions
,
)
)
def
init_weights
(
self
):
self
.
encoder
.
init_weights
()
##################
##################
# PreTrainedModel
# PreTrainedModel
##################
##################
class
DPRPretrainedContextEncoder
(
PreTrainedModel
):
class
DPRPretrainedContextEncoder
(
DPR
PreTrainedModel
):
"""
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
models.
...
@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
...
@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
base_model_prefix
=
"ctx_encoder"
base_model_prefix
=
"ctx_encoder"
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
def
init_weights
(
self
):
self
.
ctx_encoder
.
init_weights
()
class
DPRPretrainedQuestionEncoder
(
PreTrainedModel
):
class
DPRPretrainedQuestionEncoder
(
DPR
PreTrainedModel
):
"""
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
models.
...
@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
...
@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
base_model_prefix
=
"question_encoder"
base_model_prefix
=
"question_encoder"
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
def
init_weights
(
self
):
self
.
question_encoder
.
init_weights
()
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
class
DPRPretrainedReader
(
DPRPreTrainedModel
):
if
isinstance
(
module
,
BertEncoder
):
module
.
gradient_checkpointing
=
value
class
DPRPretrainedReader
(
PreTrainedModel
):
"""
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
models.
...
@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
...
@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
base_model_prefix
=
"span_predictor"
base_model_prefix
=
"span_predictor"
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
def
init_weights
(
self
):
self
.
span_predictor
.
encoder
.
init_weights
()
self
.
span_predictor
.
qa_classifier
.
apply
(
self
.
span_predictor
.
encoder
.
bert_model
.
_init_weights
)
self
.
span_predictor
.
qa_outputs
.
apply
(
self
.
span_predictor
.
encoder
.
bert_model
.
_init_weights
)
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
BertEncoder
):
module
.
gradient_checkpointing
=
value
###############
###############
# Actual Models
# Actual Models
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
41436d3d
...
@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
...
@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
class
DPRPreTrainedModel
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
class
DPRPretrainedQuestionEncoder
:
class
DPRPretrainedQuestionEncoder
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
...
...
tests/test_modeling_dpr.py
View file @
41436d3d
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
import
tempfile
import
unittest
import
unittest
from
transformers
import
DPRConfig
,
is_torch_available
from
transformers
import
DPRConfig
,
is_torch_available
...
@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_reader
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_reader
(
*
config_and_inputs
)
def
test_init_changed_config
(
self
):
config
=
self
.
model_tester
.
prepare_config_and_inputs
()[
0
]
model
=
DPRQuestionEncoder
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
eval
()
with
tempfile
.
TemporaryDirectory
()
as
tmp_dirname
:
model
.
save_pretrained
(
tmp_dirname
)
model
=
DPRQuestionEncoder
.
from_pretrained
(
tmp_dirname
,
projection_dim
=
512
)
self
.
assertIsNotNone
(
model
)
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment