Unverified commit f5741bcd, authored by Sylvain Gugger, committed by GitHub

Move QDQBert in just PyTorch block (#16062)

parent b6bdb943
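
The diff below rearranges the backend-gated blocks in the top-level `transformers/__init__.py` and the generated dummy-object modules: the QDQBert exports move out of the `is_torch_available() and is_pytorch_quantization_available()` gate and into the plain `is_torch_available()` block. A minimal sketch of the gating pattern being touched, assuming only `is_torch_available` from `transformers.file_utils`; the dict contents and the dummy-module key below are illustrative, not the library's actual entries:

```python
# Illustrative sketch of the backend-gated import structure this commit
# rearranges. Names registered here are imported lazily later on; the
# entries below are examples, not the real registry contents.
from transformers.file_utils import is_torch_available

_import_structure = {"models.qdqbert": []}

if is_torch_available():
    # After this commit, QDQBert symbols are registered under the torch-only
    # gate instead of a combined torch + pytorch_quantization gate.
    _import_structure["models.qdqbert"].extend(["QDQBertModel", "QDQBertForMaskedLM"])
else:
    # Without torch, the same names are served by generated dummy objects.
    # The "utils.dummy_pt_objects" key is assumed here, by analogy with
    # utils.dummy_scatter_objects shown in the diff.
    _import_structure["utils.dummy_pt_objects"] = ["QDQBertModel", "QDQBertForMaskedLM"]
```
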
@@ -31,8 +31,6 @@ from . import dependency_versions_check
 from .file_utils import (
     _LazyModule,
     is_flax_available,
-    is_pyctcdecode_available,
-    is_pytorch_quantization_available,
     is_scatter_available,
     is_sentencepiece_available,
     is_speech_available,
@@ -580,29 +578,6 @@ else:
         name for name in dir(dummy_scatter_objects) if not name.startswith("_")
     ]
 
-if is_torch_available() and is_pytorch_quantization_available():
-    _import_structure["models.qdqbert"].extend(
-        [
-            "QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "QDQBertForMaskedLM",
-            "QDQBertForMultipleChoice",
-            "QDQBertForNextSentencePrediction",
-            "QDQBertForQuestionAnswering",
-            "QDQBertForSequenceClassification",
-            "QDQBertForTokenClassification",
-            "QDQBertLayer",
-            "QDQBertLMHeadModel",
-            "QDQBertModel",
-            "QDQBertPreTrainedModel",
-            "load_tf_weights_in_qdqbert",
-        ]
-    )
-else:
-    from .utils import dummy_pytorch_quantization_and_torch_objects
-
-    _import_structure["utils.dummy_pytorch_quantization_and_torch_objects"] = [
-        name for name in dir(dummy_pytorch_quantization_and_torch_objects) if not name.startswith("_")
-    ]
-
 # PyTorch-backed objects
 if is_torch_available():
@@ -1288,6 +1263,22 @@ if is_torch_available():
             "ProphetNetPreTrainedModel",
         ]
     )
+    _import_structure["models.qdqbert"].extend(
+        [
+            "QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "QDQBertForMaskedLM",
+            "QDQBertForMultipleChoice",
+            "QDQBertForNextSentencePrediction",
+            "QDQBertForQuestionAnswering",
+            "QDQBertForSequenceClassification",
+            "QDQBertForTokenClassification",
+            "QDQBertLayer",
+            "QDQBertLMHeadModel",
+            "QDQBertModel",
+            "QDQBertPreTrainedModel",
+            "load_tf_weights_in_qdqbert",
+        ]
+    )
     _import_structure["models.rag"].extend(
         ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
     )
@@ -2828,24 +2819,6 @@ if TYPE_CHECKING:
     else:
         from .utils.dummy_scatter_objects import *
 
-    if is_torch_available() and is_pytorch_quantization_available():
-        from .models.qdqbert import (
-            QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
-            QDQBertForMaskedLM,
-            QDQBertForMultipleChoice,
-            QDQBertForNextSentencePrediction,
-            QDQBertForQuestionAnswering,
-            QDQBertForSequenceClassification,
-            QDQBertForTokenClassification,
-            QDQBertLayer,
-            QDQBertLMHeadModel,
-            QDQBertModel,
-            QDQBertPreTrainedModel,
-            load_tf_weights_in_qdqbert,
-        )
-    else:
-        from .utils.dummy_pytorch_quantization_and_torch_objects import *
-
     if is_torch_available():
         # Benchmarks
         from .benchmark.benchmark import PyTorchBenchmark
@@ -3428,6 +3401,20 @@ if TYPE_CHECKING:
             ProphetNetModel,
             ProphetNetPreTrainedModel,
         )
+        from .models.qdqbert import (
+            QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+            QDQBertForMaskedLM,
+            QDQBertForMultipleChoice,
+            QDQBertForNextSentencePrediction,
+            QDQBertForQuestionAnswering,
+            QDQBertForSequenceClassification,
+            QDQBertForTokenClassification,
+            QDQBertLayer,
+            QDQBertLMHeadModel,
+            QDQBertModel,
+            QDQBertPreTrainedModel,
+            load_tf_weights_in_qdqbert,
+        )
         from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
         from .models.realm import (
             REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
...
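
The blocks above only record names in `_import_structure`; `_LazyModule` (imported from `file_utils` at the top of the file) then resolves them on first attribute access. A simplified, self-contained stand-in for that lazy-import idea, meant to live in a package `__init__.py`; it is not the library's actual `_LazyModule`, and the structure contents are illustrative:

```python
# Simplified stand-in for the lazy-import mechanism: a submodule is imported
# only when one of its exported names is first accessed.
import importlib

_import_structure = {
    "models.qdqbert": ["QDQBertModel", "QDQBertForMaskedLM"],
}

# Reverse map from exported name to the submodule that defines it.
_name_to_module = {
    name: module for module, names in _import_structure.items() for name in names
}


def __getattr__(name):  # PEP 562 module-level __getattr__ (Python 3.7+)
    if name in _name_to_module:
        submodule = importlib.import_module("." + _name_to_module[name], __package__)
        return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```
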
@@ -3044,6 +3044,83 @@ class ProphetNetPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class QDQBertForMaskedLM(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertForMultipleChoice(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertForNextSentencePrediction(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertLayer(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertLMHeadModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class QDQBertPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+def load_tf_weights_in_qdqbert(*args, **kwargs):
+    requires_backends(load_tf_weights_in_qdqbert, ["torch"])
+
+
 class RagModel(metaclass=DummyObject):
     _backends = ["torch"]
...
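
The dummy classes added above follow the `DummyObject` / `requires_backends` pattern: when the required backend is missing, using the placeholder raises an informative error that names the backend, rather than failing at package import time. A self-contained re-implementation of the idea for illustration only; this is not the transformers code, only the torch check is modelled, and the `Sketch` class name is made up:

```python
# Illustrative re-implementation of the DummyObject / requires_backends
# pattern used by the generated dummy modules above.


def _torch_installed() -> bool:
    try:
        import torch  # noqa: F401

        return True
    except ImportError:
        return False


def requires_backends(obj, backends):
    # Only the "torch" backend is actually checked in this sketch.
    missing = [b for b in backends if b == "torch" and not _torch_installed()]
    if missing:
        name = getattr(obj, "__name__", obj.__class__.__name__)
        raise ImportError(f"{name} requires the following backend(s): {', '.join(missing)}")


class DummyObject(type):
    """Metaclass so that attribute access on the placeholder class itself
    (e.g. QDQBertModelSketch.from_pretrained) also reports missing backends."""

    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


class QDQBertModelSketch(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```

In the library, the real class is exported whenever the backend check passes, so placeholders like these are only reached in environments without torch.
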
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..file_utils import DummyObject, requires_backends


QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


class QDQBertForMaskedLM(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertForMultipleChoice(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertForNextSentencePrediction(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertForQuestionAnswering(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertForSequenceClassification(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertForTokenClassification(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertLayer(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertLMHeadModel(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertModel(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


class QDQBertPreTrainedModel(metaclass=DummyObject):
    _backends = ["pytorch_quantization", "torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["pytorch_quantization", "torch"])


def load_tf_weights_in_qdqbert(*args, **kwargs):
    requires_backends(load_tf_weights_in_qdqbert, ["pytorch_quantization", "torch"])
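
With the combined gate removed from `__init__.py`, this generated `dummy_pytorch_quantization_and_torch_objects` module is no longer referenced there, and the QDQBert names fall back to the plain torch dummies when torch is absent. Code that still needs to know at runtime whether pytorch_quantization is installed can check explicitly; a minimal sketch, assuming `is_pytorch_quantization_available` remains defined in `transformers.file_utils` (that module itself is not edited in this diff):

```python
# Optional runtime check: the import-time gate for QDQBert is now torch only,
# so the presence of pytorch_quantization can be verified separately.
from transformers.file_utils import is_pytorch_quantization_available, is_torch_available

if is_torch_available() and not is_pytorch_quantization_available():
    print(
        "torch is installed but pytorch_quantization is not; "
        "install pytorch_quantization to use QDQBert's quantized layers."
    )
```
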