Unverified Commit 41436d3d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DPR] Correct init (#13796)

* update

* add to docs and init

* make fix-copies
parent 44eb8bde
......@@ -41,6 +41,13 @@ DPRConfig
:members:
DPRPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRPreTrainedModel
:members:
DPRContextEncoderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -773,6 +773,7 @@ if is_torch_available():
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPRContextEncoder",
"DPRPretrainedContextEncoder",
"DPRPreTrainedModel",
"DPRPretrainedQuestionEncoder",
"DPRPretrainedReader",
"DPRQuestionEncoder",
......@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPreTrainedModel,
DPRPretrainedQuestionEncoder,
DPRPretrainedReader,
DPRQuestionEncoder,
......
......@@ -46,6 +46,7 @@ if is_torch_available():
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPRContextEncoder",
"DPRPretrainedContextEncoder",
"DPRPreTrainedModel",
"DPRPretrainedQuestionEncoder",
"DPRPretrainedReader",
"DPRQuestionEncoder",
......@@ -89,6 +90,7 @@ if TYPE_CHECKING:
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPreTrainedModel,
DPRPretrainedQuestionEncoder,
DPRPretrainedReader,
DPRQuestionEncoder,
......
......@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
class DPREncoder(PreTrainedModel):
class DPRPreTrainedModel(PreTrainedModel):
    """Common base for all DPR models: owns the weight-initialization scheme and
    gradient-checkpointing toggle shared by the encoder/reader subclasses.

    NOTE(review): this class intentionally mirrors BERT's ``_init_weights`` (DPR
    encoders wrap a BERT model), so the init constants must stay in sync with it.
    """

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero after the
                # normal_() above re-randomized the whole table.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm starts as the identity transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the wrapped BertEncoder supports checkpointing; other submodules
        # (projections, classifiers) are left untouched.
        if isinstance(module, BertEncoder):
            module.gradient_checkpointing = value
class DPREncoder(DPRPreTrainedModel):
base_model_prefix = "bert_model"
......@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
return self.encode_proj.out_features
return self.bert_model.config.hidden_size
def init_weights(self):
self.bert_model.init_weights()
if self.projection_dim > 0:
self.encode_proj.apply(self.bert_model._init_weights)
class DPRSpanPredictor(PreTrainedModel):
class DPRSpanPredictor(DPRPreTrainedModel):
base_model_prefix = "encoder"
......@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
attentions=outputs.attentions,
)
def init_weights(self):
self.encoder.init_weights()
##################
# PreTrainedModel
##################
class DPRPretrainedContextEncoder(PreTrainedModel):
class DPRPretrainedContextEncoder(DPRPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
......@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
base_model_prefix = "ctx_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self):
self.ctx_encoder.init_weights()
class DPRPretrainedQuestionEncoder(PreTrainedModel):
class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
......@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
base_model_prefix = "question_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self):
self.question_encoder.init_weights()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
class DPRPretrainedReader(PreTrainedModel):
class DPRPretrainedReader(DPRPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
......@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
base_model_prefix = "span_predictor"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self):
self.span_predictor.encoder.init_weights()
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
###############
# Actual Models
......
......@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
requires_backends(self, ["torch"])
# Auto-generated dummy object (see utils/dummy_pt_objects.py convention):
# stands in for the real DPRPreTrainedModel when torch is not installed and
# raises an informative ImportError on any attempted use.
class DPRPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class DPRPretrainedQuestionEncoder:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import tempfile
import unittest
from transformers import DPRConfig, is_torch_available
......@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reader(*config_and_inputs)
def test_init_changed_config(self):
    """Regression test for #13796: reloading a saved DPR model with a config
    override (here ``projection_dim=512``) must not crash during weight init,
    even though the new projection layer has no pretrained weights to load."""
    config = self.model_tester.prepare_config_and_inputs()[0]

    model = DPRQuestionEncoder(config=config)
    model.to(torch_device)
    model.eval()

    with tempfile.TemporaryDirectory() as tmp_dirname:
        model.save_pretrained(tmp_dirname)
        # Reload with a changed projection_dim: the extra encode_proj weights
        # are missing from the checkpoint and must be freshly initialized by
        # DPRPreTrainedModel._init_weights instead of raising.
        model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)

    self.assertIsNotNone(model)
@slow
def test_model_from_pretrained(self):
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment