Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f5741bcd
Unverified
Commit
f5741bcd
authored
Mar 11, 2022
by
Sylvain Gugger
Committed by
GitHub
Mar 11, 2022
Browse files
Move QDQBert in just PyTorch block (#16062)
parent
b6bdb943
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
123 deletions
+107
-123
src/transformers/__init__.py
src/transformers/__init__.py
+30
-43
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+77
-0
src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py
...ers/utils/dummy_pytorch_quantization_and_torch_objects.py
+0
-80
No files found.
src/transformers/__init__.py
View file @
f5741bcd
...
@@ -31,8 +31,6 @@ from . import dependency_versions_check
...
@@ -31,8 +31,6 @@ from . import dependency_versions_check
from
.file_utils
import
(
from
.file_utils
import
(
_LazyModule
,
_LazyModule
,
is_flax_available
,
is_flax_available
,
is_pyctcdecode_available
,
is_pytorch_quantization_available
,
is_scatter_available
,
is_scatter_available
,
is_sentencepiece_available
,
is_sentencepiece_available
,
is_speech_available
,
is_speech_available
,
...
@@ -580,29 +578,6 @@ else:
...
@@ -580,29 +578,6 @@ else:
name
for
name
in
dir
(
dummy_scatter_objects
)
if
not
name
.
startswith
(
"_"
)
name
for
name
in
dir
(
dummy_scatter_objects
)
if
not
name
.
startswith
(
"_"
)
]
]
if
is_torch_available
()
and
is_pytorch_quantization_available
():
_import_structure
[
"models.qdqbert"
].
extend
(
[
"QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"QDQBertForMaskedLM"
,
"QDQBertForMultipleChoice"
,
"QDQBertForNextSentencePrediction"
,
"QDQBertForQuestionAnswering"
,
"QDQBertForSequenceClassification"
,
"QDQBertForTokenClassification"
,
"QDQBertLayer"
,
"QDQBertLMHeadModel"
,
"QDQBertModel"
,
"QDQBertPreTrainedModel"
,
"load_tf_weights_in_qdqbert"
,
]
)
else
:
from
.utils
import
dummy_pytorch_quantization_and_torch_objects
_import_structure
[
"utils.dummy_pytorch_quantization_and_torch_objects"
]
=
[
name
for
name
in
dir
(
dummy_pytorch_quantization_and_torch_objects
)
if
not
name
.
startswith
(
"_"
)
]
# PyTorch-backed objects
# PyTorch-backed objects
if
is_torch_available
():
if
is_torch_available
():
...
@@ -1288,6 +1263,22 @@ if is_torch_available():
...
@@ -1288,6 +1263,22 @@ if is_torch_available():
"ProphetNetPreTrainedModel"
,
"ProphetNetPreTrainedModel"
,
]
]
)
)
_import_structure
[
"models.qdqbert"
].
extend
(
[
"QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"QDQBertForMaskedLM"
,
"QDQBertForMultipleChoice"
,
"QDQBertForNextSentencePrediction"
,
"QDQBertForQuestionAnswering"
,
"QDQBertForSequenceClassification"
,
"QDQBertForTokenClassification"
,
"QDQBertLayer"
,
"QDQBertLMHeadModel"
,
"QDQBertModel"
,
"QDQBertPreTrainedModel"
,
"load_tf_weights_in_qdqbert"
,
]
)
_import_structure
[
"models.rag"
].
extend
(
_import_structure
[
"models.rag"
].
extend
(
[
"RagModel"
,
"RagPreTrainedModel"
,
"RagSequenceForGeneration"
,
"RagTokenForGeneration"
]
[
"RagModel"
,
"RagPreTrainedModel"
,
"RagSequenceForGeneration"
,
"RagTokenForGeneration"
]
)
)
...
@@ -2828,24 +2819,6 @@ if TYPE_CHECKING:
...
@@ -2828,24 +2819,6 @@ if TYPE_CHECKING:
else
:
else
:
from
.utils.dummy_scatter_objects
import
*
from
.utils.dummy_scatter_objects
import
*
if
is_torch_available
()
and
is_pytorch_quantization_available
():
from
.models.qdqbert
import
(
QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST
,
QDQBertForMaskedLM
,
QDQBertForMultipleChoice
,
QDQBertForNextSentencePrediction
,
QDQBertForQuestionAnswering
,
QDQBertForSequenceClassification
,
QDQBertForTokenClassification
,
QDQBertLayer
,
QDQBertLMHeadModel
,
QDQBertModel
,
QDQBertPreTrainedModel
,
load_tf_weights_in_qdqbert
,
)
else
:
from
.utils.dummy_pytorch_quantization_and_torch_objects
import
*
if
is_torch_available
():
if
is_torch_available
():
# Benchmarks
# Benchmarks
from
.benchmark.benchmark
import
PyTorchBenchmark
from
.benchmark.benchmark
import
PyTorchBenchmark
...
@@ -3428,6 +3401,20 @@ if TYPE_CHECKING:
...
@@ -3428,6 +3401,20 @@ if TYPE_CHECKING:
ProphetNetModel
,
ProphetNetModel
,
ProphetNetPreTrainedModel
,
ProphetNetPreTrainedModel
,
)
)
from
.models.qdqbert
import
(
QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST
,
QDQBertForMaskedLM
,
QDQBertForMultipleChoice
,
QDQBertForNextSentencePrediction
,
QDQBertForQuestionAnswering
,
QDQBertForSequenceClassification
,
QDQBertForTokenClassification
,
QDQBertLayer
,
QDQBertLMHeadModel
,
QDQBertModel
,
QDQBertPreTrainedModel
,
load_tf_weights_in_qdqbert
,
)
from
.models.rag
import
RagModel
,
RagPreTrainedModel
,
RagSequenceForGeneration
,
RagTokenForGeneration
from
.models.rag
import
RagModel
,
RagPreTrainedModel
,
RagSequenceForGeneration
,
RagTokenForGeneration
from
.models.realm
import
(
from
.models.realm
import
(
REALM_PRETRAINED_MODEL_ARCHIVE_LIST
,
REALM_PRETRAINED_MODEL_ARCHIVE_LIST
,
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
f5741bcd
...
@@ -3044,6 +3044,83 @@ class ProphetNetPreTrainedModel(metaclass=DummyObject):
...
@@ -3044,6 +3044,83 @@ class ProphetNetPreTrainedModel(metaclass=DummyObject):
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST
=
None
class
QDQBertForMaskedLM
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertForMultipleChoice
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertForNextSentencePrediction
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertForQuestionAnswering
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertForSequenceClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertForTokenClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertLayer
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertLMHeadModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
QDQBertPreTrainedModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
def
load_tf_weights_in_qdqbert
(
*
args
,
**
kwargs
):
requires_backends
(
load_tf_weights_in_qdqbert
,
[
"torch"
])
class
RagModel
(
metaclass
=
DummyObject
):
class
RagModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
_backends
=
[
"torch"
]
...
...
src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py
deleted
100644 → 0
View file @
b6bdb943
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..file_utils
import
DummyObject
,
requires_backends
QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST
=
None
class
QDQBertForMaskedLM
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertForMultipleChoice
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertForNextSentencePrediction
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertForQuestionAnswering
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertForSequenceClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertForTokenClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertLayer
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertLMHeadModel
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertModel
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
class
QDQBertPreTrainedModel
(
metaclass
=
DummyObject
):
_backends
=
[
"pytorch_quantization"
,
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"pytorch_quantization"
,
"torch"
])
def
load_tf_weights_in_qdqbert
(
*
args
,
**
kwargs
):
requires_backends
(
load_tf_weights_in_qdqbert
,
[
"pytorch_quantization"
,
"torch"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment