Unverified Commit 3e8d17e6 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add forward method to dummy models (#14419)

* Add forward method to dummy models

* Fix quality
parent 040fd471
This diff is collapsed.
...@@ -13,6 +13,9 @@ class TapasForMaskedLM: ...@@ -13,6 +13,9 @@ class TapasForMaskedLM:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"]) requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasForQuestionAnswering: class TapasForQuestionAnswering:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -22,6 +25,9 @@ class TapasForQuestionAnswering: ...@@ -22,6 +25,9 @@ class TapasForQuestionAnswering:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"]) requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasForSequenceClassification: class TapasForSequenceClassification:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -31,6 +37,9 @@ class TapasForSequenceClassification: ...@@ -31,6 +37,9 @@ class TapasForSequenceClassification:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"]) requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasModel: class TapasModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -40,6 +49,9 @@ class TapasModel: ...@@ -40,6 +49,9 @@ class TapasModel:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"]) requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
class TapasPreTrainedModel: class TapasPreTrainedModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -49,6 +61,9 @@ class TapasPreTrainedModel: ...@@ -49,6 +61,9 @@ class TapasPreTrainedModel:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["scatter"]) requires_backends(cls, ["scatter"])
def forward(self, *args, **kwargs):
requires_backends(self, ["scatter"])
def load_tf_weights_in_tapas(*args, **kwargs): def load_tf_weights_in_tapas(*args, **kwargs):
requires_backends(load_tf_weights_in_tapas, ["scatter"]) requires_backends(load_tf_weights_in_tapas, ["scatter"])
This diff is collapsed.
...@@ -13,6 +13,9 @@ class DetrForObjectDetection: ...@@ -13,6 +13,9 @@ class DetrForObjectDetection:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"]) requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrForSegmentation: class DetrForSegmentation:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -22,6 +25,9 @@ class DetrForSegmentation: ...@@ -22,6 +25,9 @@ class DetrForSegmentation:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"]) requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrModel: class DetrModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -31,6 +37,9 @@ class DetrModel: ...@@ -31,6 +37,9 @@ class DetrModel:
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"]) requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrPreTrainedModel: class DetrPreTrainedModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -39,3 +48,6 @@ class DetrPreTrainedModel: ...@@ -39,3 +48,6 @@ class DetrPreTrainedModel:
@classmethod @classmethod
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"]) requires_backends(cls, ["timm", "vision"])
def forward(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
...@@ -43,6 +43,30 @@ class {0}: ...@@ -43,6 +43,30 @@ class {0}:
requires_backends(cls, {1}) requires_backends(cls, {1})
""" """
PT_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def forward(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
TF_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def call(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
FLAX_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def __call__(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
DUMMY_CLASS = """ DUMMY_CLASS = """
class {0}: class {0}:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -102,8 +126,7 @@ def read_init(): ...@@ -102,8 +126,7 @@ def read_init():
def create_dummy_object(name, backend_name): def create_dummy_object(name, backend_name):
"""Create the code for the dummy object corresponding to `name`.""" """Create the code for the dummy object corresponding to `name`."""
_pretrained = [ _models = [
"Config",
"ForCausalLM", "ForCausalLM",
"ForConditionalGeneration", "ForConditionalGeneration",
"ForMaskedLM", "ForMaskedLM",
...@@ -114,14 +137,24 @@ def create_dummy_object(name, backend_name): ...@@ -114,14 +137,24 @@ def create_dummy_object(name, backend_name):
"ForSequenceClassification", "ForSequenceClassification",
"ForTokenClassification", "ForTokenClassification",
"Model", "Model",
"Tokenizer",
"Processor",
] ]
_pretrained = ["Config", "Tokenizer", "Processor"]
if name.isupper(): if name.isupper():
return DUMMY_CONSTANT.format(name) return DUMMY_CONSTANT.format(name)
elif name.islower(): elif name.islower():
return DUMMY_FUNCTION.format(name, backend_name) return DUMMY_FUNCTION.format(name, backend_name)
else: else:
is_model = False
for part in _models:
if part in name:
is_model = True
break
if is_model:
if name.startswith("TF"):
return TF_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
if name.startswith("Flax"):
return FLAX_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
return PT_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
is_pretrained = False is_pretrained = False
for part in _pretrained: for part in _pretrained:
if part in name: if part in name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment