Unverified Commit 1a92bc57 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dummy objects for quantization (#14478)

* Fix dummy objects for quantization

* Add more models
parent c9d2cf85
......@@ -477,6 +477,13 @@ class FlaxBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBertForPreTraining:
def __init__(self, *args, **kwargs):
......
......@@ -56,6 +56,13 @@ class TextDatasetForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamScorer:
def __init__(self, *args, **kwargs):
......@@ -783,6 +790,13 @@ class BertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BertForPreTraining:
def __init__(self, *args, **kwargs):
......@@ -2106,6 +2120,13 @@ class FNetForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class FNetForPreTraining:
def __init__(self, *args, **kwargs):
......@@ -3254,6 +3275,13 @@ class MegatronBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegatronBertForPreTraining:
def __init__(self, *args, **kwargs):
......@@ -3373,6 +3401,13 @@ class MobileBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MobileBertForPreTraining:
def __init__(self, *args, **kwargs):
......
......@@ -13,6 +13,9 @@ class QDQBertForMaskedLM:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertForMultipleChoice:
def __init__(self, *args, **kwargs):
......@@ -22,11 +25,21 @@ class QDQBertForMultipleChoice:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertForQuestionAnswering:
def __init__(self, *args, **kwargs):
......@@ -36,6 +49,9 @@ class QDQBertForQuestionAnswering:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertForSequenceClassification:
def __init__(self, *args, **kwargs):
......@@ -45,6 +61,9 @@ class QDQBertForSequenceClassification:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertForTokenClassification:
def __init__(self, *args, **kwargs):
......@@ -54,6 +73,9 @@ class QDQBertForTokenClassification:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertLayer:
def __init__(self, *args, **kwargs):
......@@ -68,6 +90,9 @@ class QDQBertLMHeadModel:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertModel:
def __init__(self, *args, **kwargs):
......@@ -77,6 +102,9 @@ class QDQBertModel:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
class QDQBertPreTrainedModel:
def __init__(self, *args, **kwargs):
......@@ -86,6 +114,9 @@ class QDQBertPreTrainedModel:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["pytorch_quantization", "torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["pytorch_quantization", "torch"])
def load_tf_weights_in_qdqbert(*args, **kwargs):
requires_backends(load_tf_weights_in_qdqbert, ["pytorch_quantization", "torch"])
......@@ -452,6 +452,13 @@ class TFBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
def call(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBertForPreTraining:
def __init__(self, *args, **kwargs):
......@@ -1774,6 +1781,13 @@ class TFMobileBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
def call(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFMobileBertForPreTraining:
def __init__(self, *args, **kwargs):
......
......@@ -23,7 +23,7 @@ import re
PATH_TO_TRANSFORMERS = "src/transformers"
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z]*)_available()")
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)")
......@@ -131,6 +131,7 @@ def create_dummy_object(name, backend_name):
"ForConditionalGeneration",
"ForMaskedLM",
"ForMultipleChoice",
"ForNextSentencePrediction",
"ForObjectDetection",
"ForQuestionAnswering",
"ForSegmentation",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment