Unverified Commit 41436d3d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DPR] Correct init (#13796)

* update

* add to docs and init

* make fix-copies
parent 44eb8bde
...@@ -41,6 +41,13 @@ DPRConfig ...@@ -41,6 +41,13 @@ DPRConfig
:members: :members:
DPRPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRPreTrainedModel
:members:
DPRContextEncoderTokenizer DPRContextEncoderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -773,6 +773,7 @@ if is_torch_available(): ...@@ -773,6 +773,7 @@ if is_torch_available():
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST", "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPRContextEncoder", "DPRContextEncoder",
"DPRPretrainedContextEncoder", "DPRPretrainedContextEncoder",
"DPRPreTrainedModel",
"DPRPretrainedQuestionEncoder", "DPRPretrainedQuestionEncoder",
"DPRPretrainedReader", "DPRPretrainedReader",
"DPRQuestionEncoder", "DPRQuestionEncoder",
...@@ -2512,6 +2513,7 @@ if TYPE_CHECKING: ...@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder, DPRContextEncoder,
DPRPretrainedContextEncoder, DPRPretrainedContextEncoder,
DPRPreTrainedModel,
DPRPretrainedQuestionEncoder, DPRPretrainedQuestionEncoder,
DPRPretrainedReader, DPRPretrainedReader,
DPRQuestionEncoder, DPRQuestionEncoder,
......
...@@ -46,6 +46,7 @@ if is_torch_available(): ...@@ -46,6 +46,7 @@ if is_torch_available():
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST", "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPRContextEncoder", "DPRContextEncoder",
"DPRPretrainedContextEncoder", "DPRPretrainedContextEncoder",
"DPRPreTrainedModel",
"DPRPretrainedQuestionEncoder", "DPRPretrainedQuestionEncoder",
"DPRPretrainedReader", "DPRPretrainedReader",
"DPRQuestionEncoder", "DPRQuestionEncoder",
...@@ -89,6 +90,7 @@ if TYPE_CHECKING: ...@@ -89,6 +90,7 @@ if TYPE_CHECKING:
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder, DPRContextEncoder,
DPRPretrainedContextEncoder, DPRPretrainedContextEncoder,
DPRPreTrainedModel,
DPRPretrainedQuestionEncoder, DPRPretrainedQuestionEncoder,
DPRPretrainedReader, DPRPretrainedReader,
DPRQuestionEncoder, DPRQuestionEncoder,
......
...@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput): ...@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
class DPREncoder(PreTrainedModel): class DPRPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
class DPREncoder(DPRPreTrainedModel):
base_model_prefix = "bert_model" base_model_prefix = "bert_model"
...@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel): ...@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
return self.encode_proj.out_features return self.encode_proj.out_features
return self.bert_model.config.hidden_size return self.bert_model.config.hidden_size
def init_weights(self):
self.bert_model.init_weights()
if self.projection_dim > 0:
self.encode_proj.apply(self.bert_model._init_weights)
class DPRSpanPredictor(PreTrainedModel): class DPRSpanPredictor(DPRPreTrainedModel):
base_model_prefix = "encoder" base_model_prefix = "encoder"
...@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel): ...@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
) )
def init_weights(self):
self.encoder.init_weights()
################## ##################
# PreTrainedModel # PreTrainedModel
################## ##################
class DPRPretrainedContextEncoder(PreTrainedModel): class DPRPretrainedContextEncoder(DPRPreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models. models.
...@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel): ...@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
base_model_prefix = "ctx_encoder" base_model_prefix = "ctx_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self):
self.ctx_encoder.init_weights()
class DPRPretrainedQuestionEncoder(PreTrainedModel): class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models. models.
...@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): ...@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
base_model_prefix = "question_encoder" base_model_prefix = "question_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self):
self.question_encoder.init_weights()
def _set_gradient_checkpointing(self, module, value=False): class DPRPretrainedReader(DPRPreTrainedModel):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
class DPRPretrainedReader(PreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models. models.
...@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel): ...@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
base_model_prefix = "span_predictor" base_model_prefix = "span_predictor"
_keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self):
self.span_predictor.encoder.init_weights()
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
############### ###############
# Actual Models # Actual Models
......
...@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder: ...@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
# Dummy object emitted by `make fix-copies` (utils/check_dummies.py): stands in
# for the real DPRPreTrainedModel when torch is not installed and raises an
# informative ImportError on use. Keep byte-identical to the generator template.
class DPRPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class DPRPretrainedQuestionEncoder: class DPRPretrainedQuestionEncoder:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
from transformers import DPRConfig, is_torch_available from transformers import DPRConfig, is_torch_available
...@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reader(*config_and_inputs) self.model_tester.create_and_check_reader(*config_and_inputs)
def test_init_changed_config(self):
config = self.model_tester.prepare_config_and_inputs()[0]
model = DPRQuestionEncoder(config=config)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)
model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)
self.assertIsNotNone(model)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment