Unverified Commit 3e8d17e6 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add forward method to dummy models (#14419)

* Add forward method to dummy models

* Fix quality
parent 040fd471
This diff is collapsed.
......@@ -13,6 +13,9 @@ class TapasForMaskedLM:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasForQuestionAnswering:
def __init__(self, *args, **kwargs):
......@@ -22,6 +25,9 @@ class TapasForQuestionAnswering:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasForSequenceClassification:
def __init__(self, *args, **kwargs):
......@@ -31,6 +37,9 @@ class TapasForSequenceClassification:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasModel:
def __init__(self, *args, **kwargs):
......@@ -40,6 +49,9 @@ class TapasModel:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasPreTrainedModel:
def __init__(self, *args, **kwargs):
......@@ -49,6 +61,9 @@ class TapasPreTrainedModel:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
def load_tf_weights_in_tapas(*args, **kwargs):
requires_backends(load_tf_weights_in_tapas, ["scatter"])
This diff is collapsed.
......@@ -13,6 +13,9 @@ class DetrForObjectDetection:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrForSegmentation:
def __init__(self, *args, **kwargs):
......@@ -22,6 +25,9 @@ class DetrForSegmentation:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrModel:
def __init__(self, *args, **kwargs):
......@@ -31,6 +37,9 @@ class DetrModel:
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrPreTrainedModel:
def __init__(self, *args, **kwargs):
......@@ -39,3 +48,6 @@ class DetrPreTrainedModel:
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
......@@ -43,6 +43,30 @@ class {0}:
requires_backends(cls, {1})
"""
PT_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def forward(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
TF_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def call(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
FLAX_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def __call__(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
DUMMY_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
......@@ -102,8 +126,7 @@ def read_init():
def create_dummy_object(name, backend_name):
"""Create the code for the dummy object corresponding to `name`."""
_pretrained = [
"Config",
_models = [
"ForCausalLM",
"ForConditionalGeneration",
"ForMaskedLM",
......@@ -114,14 +137,24 @@ def create_dummy_object(name, backend_name):
"ForSequenceClassification",
"ForTokenClassification",
"Model",
"Tokenizer",
"Processor",
]
_pretrained = ["Config", "Tokenizer", "Processor"]
if name.isupper():
return DUMMY_CONSTANT.format(name)
elif name.islower():
return DUMMY_FUNCTION.format(name, backend_name)
else:
is_model = False
for part in _models:
if part in name:
is_model = True
break
if is_model:
if name.startswith("TF"):
return TF_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
if name.startswith("Flax"):
return FLAX_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
return PT_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
is_pretrained = False
for part in _pretrained:
if part in name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment