Commit 41436d3d (unverified), authored Sep 30, 2021 by Patrick von Platen, committed via GitHub on Sep 30, 2021
Parent: 44eb8bde

[DPR] Correct init (#13796)

* update
* add to docs and init
* make fix-copies
Showing 6 changed files with 61 additions and 32 deletions (+61, -32):

* docs/source/model_doc/dpr.rst (+7, -0)
* src/transformers/__init__.py (+2, -0)
* src/transformers/models/dpr/__init__.py (+2, -0)
* src/transformers/models/dpr/modeling_dpr.py (+27, -32)
* src/transformers/utils/dummy_pt_objects.py (+9, -0)
* tests/test_modeling_dpr.py (+14, -0)
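Taken together, the changes introduce a shared DPRPreTrainedModel base class in modeling_dpr.py, re-export it from the package __init__ files, document it, stub it in the dummy PyTorch objects, and add a regression test. A rough sketch of the resulting class hierarchy (simplified from the diff below; the real classes carry full bodies and docstrings):

    from transformers import PreTrainedModel

    class DPRPreTrainedModel(PreTrainedModel):
        def _init_weights(self, module): ...                          # shared Linear/Embedding/LayerNorm init
        def _set_gradient_checkpointing(self, module, value=False): ...

    # The DPR-specific classes now inherit the shared init logic instead of
    # each overriding init_weights:
    class DPREncoder(DPRPreTrainedModel): ...
    class DPRSpanPredictor(DPRPreTrainedModel): ...
    class DPRPretrainedContextEncoder(DPRPreTrainedModel): ...
    class DPRPretrainedQuestionEncoder(DPRPreTrainedModel): ...
    class DPRPretrainedReader(DPRPreTrainedModel): ...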
docs/source/model_doc/dpr.rst

@@ -41,6 +41,13 @@ DPRConfig
     :members:


+DPRPreTrainedModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.DPRPreTrainedModel
+    :members:
+
+
 DPRContextEncoderTokenizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
src/transformers/__init__.py

@@ -773,6 +773,7 @@ if is_torch_available():
             "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
             "DPRContextEncoder",
             "DPRPretrainedContextEncoder",
+            "DPRPreTrainedModel",
             "DPRPretrainedQuestionEncoder",
             "DPRPretrainedReader",
             "DPRQuestionEncoder",
@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
            DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
            DPRContextEncoder,
            DPRPretrainedContextEncoder,
+           DPRPreTrainedModel,
            DPRPretrainedQuestionEncoder,
            DPRPretrainedReader,
            DPRQuestionEncoder,
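With the export list updated here (and in the per-model __init__ below), DPRPreTrainedModel becomes importable from the top-level package. A minimal check, assuming PyTorch is installed and the transformers version includes this commit:

    from transformers import DPRPreTrainedModel, PreTrainedModel

    # The new public class is a regular PreTrainedModel subclass.
    assert issubclass(DPRPreTrainedModel, PreTrainedModel)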
src/transformers/models/dpr/__init__.py

@@ -46,6 +46,7 @@ if is_torch_available():
         "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
         "DPRContextEncoder",
         "DPRPretrainedContextEncoder",
+        "DPRPreTrainedModel",
         "DPRPretrainedQuestionEncoder",
         "DPRPretrainedReader",
         "DPRQuestionEncoder",
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
         DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
         DPRContextEncoder,
         DPRPretrainedContextEncoder,
+        DPRPreTrainedModel,
         DPRPretrainedQuestionEncoder,
         DPRPretrainedReader,
         DPRQuestionEncoder,
src/transformers/models/dpr/modeling_dpr.py

@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None


-class DPREncoder(PreTrainedModel):
+class DPRPreTrainedModel(PreTrainedModel):
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+
+
+class DPREncoder(DPRPreTrainedModel):

     base_model_prefix = "bert_model"

@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
             return self.encode_proj.out_features
         return self.bert_model.config.hidden_size

-    def init_weights(self):
-        self.bert_model.init_weights()
-        if self.projection_dim > 0:
-            self.encode_proj.apply(self.bert_model._init_weights)
-

-class DPRSpanPredictor(PreTrainedModel):
+class DPRSpanPredictor(DPRPreTrainedModel):

     base_model_prefix = "encoder"

@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
             attentions=outputs.attentions,
         )

-    def init_weights(self):
-        self.encoder.init_weights()
-

 ##################
 # PreTrainedModel
 ##################


-class DPRPretrainedContextEncoder(PreTrainedModel):
+class DPRPretrainedContextEncoder(DPRPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.

@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     base_model_prefix = "ctx_encoder"
     _keys_to_ignore_on_load_missing = [r"position_ids"]

-    def init_weights(self):
-        self.ctx_encoder.init_weights()
-

-class DPRPretrainedQuestionEncoder(PreTrainedModel):
+class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.

@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     base_model_prefix = "question_encoder"
     _keys_to_ignore_on_load_missing = [r"position_ids"]

-    def init_weights(self):
-        self.question_encoder.init_weights()
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing = value
-

-class DPRPretrainedReader(PreTrainedModel):
+class DPRPretrainedReader(DPRPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.

@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
     base_model_prefix = "span_predictor"
     _keys_to_ignore_on_load_missing = [r"position_ids"]

-    def init_weights(self):
-        self.span_predictor.encoder.init_weights()
-        self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
-        self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing = value
-

 ###############
 # Actual Models
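With the shared base class in place, weight initialization for DPR modules goes through the standard _init_weights path instead of the per-class init_weights overrides removed above. A minimal sketch (hypothetical config values; assumes PyTorch and a transformers version containing this commit):

    from transformers import DPRConfig, DPRQuestionEncoder

    # encode_proj is a freshly created nn.Linear; it is initialized through
    # DPRPreTrainedModel._init_weights when the model is constructed.
    config = DPRConfig(projection_dim=128)
    model = DPRQuestionEncoder(config)
    print(type(model.question_encoder.encode_proj))  # <class 'torch.nn.modules.linear.Linear'>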
src/transformers/utils/dummy_pt_objects.py

@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
         requires_backends(self, ["torch"])


+class DPRPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class DPRPretrainedQuestionEncoder:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
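For context, a hedged illustration of how such a dummy object behaves: in an environment without PyTorch the top-level import still succeeds, but using the class raises an error naming the missing backend (exact wording depends on the installed transformers version):

    from transformers import DPRPreTrainedModel  # resolves to the dummy class when torch is absent

    try:
        DPRPreTrainedModel()
    except ImportError as err:
        print(err)  # requires_backends reports that the "torch" backend must be installed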
tests/test_modeling_dpr.py

@@ -14,6 +14,7 @@
 # limitations under the License.


+import tempfile
 import unittest

 from transformers import DPRConfig, is_torch_available
@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reader(*config_and_inputs)

+    def test_init_changed_config(self):
+        config = self.model_tester.prepare_config_and_inputs()[0]
+
+        model = DPRQuestionEncoder(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+            model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)
+
+        self.assertIsNotNone(model)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
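The new test exercises the fix end to end: a DPRQuestionEncoder is saved and then reloaded with projection_dim=512, so the extra encode_proj weights missing from the checkpoint are initialized by the shared _init_weights added in this commit. A user-level sketch of the same flow against a public checkpoint (checkpoint name from the Hugging Face Hub; assumes network access and PyTorch):

    from transformers import DPRQuestionEncoder

    # Reload a released DPR question encoder with a projection head it was not trained with;
    # the new encode_proj layer is randomly initialized rather than loaded from the checkpoint.
    model = DPRQuestionEncoder.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base", projection_dim=512
    )
    print(model.config.projection_dim)  # 512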