Unverified Commit d07b540a authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Have dummy processors have a `from_pretrained` method (#12145)

parent 9b393240
...@@ -6,11 +6,19 @@ class FlaxLogitsProcessor: ...@@ -6,11 +6,19 @@ class FlaxLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxLogitsProcessorList: class FlaxLogitsProcessorList:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxLogitsWarper: class FlaxLogitsWarper:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
...@@ -127,31 +127,55 @@ class ForcedBOSTokenLogitsProcessor: ...@@ -127,31 +127,55 @@ class ForcedBOSTokenLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ForcedEOSTokenLogitsProcessor: class ForcedEOSTokenLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HammingDiversityLogitsProcessor: class HammingDiversityLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class InfNanRemoveLogitsProcessor: class InfNanRemoveLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LogitsProcessor: class LogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LogitsProcessorList: class LogitsProcessorList:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LogitsWarper: class LogitsWarper:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -162,26 +186,46 @@ class MinLengthLogitsProcessor: ...@@ -162,26 +186,46 @@ class MinLengthLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class NoBadWordsLogitsProcessor: class NoBadWordsLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class NoRepeatNGramLogitsProcessor: class NoRepeatNGramLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PrefixConstrainedLogitsProcessor: class PrefixConstrainedLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class RepetitionPenaltyLogitsProcessor: class RepetitionPenaltyLogitsProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class TemperatureLogitsWarper: class TemperatureLogitsWarper:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
...@@ -5,3 +5,7 @@ from ..file_utils import requires_backends ...@@ -5,3 +5,7 @@ from ..file_utils import requires_backends
class Speech2TextProcessor: class Speech2TextProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece", "speech"]) requires_backends(self, ["sentencepiece", "speech"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["sentencepiece", "speech"])
...@@ -16,6 +16,10 @@ class CLIPProcessor: ...@@ -16,6 +16,10 @@ class CLIPProcessor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"]) requires_backends(self, ["vision"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["vision"])
class DeiTFeatureExtractor: class DeiTFeatureExtractor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
...@@ -115,6 +115,7 @@ def create_dummy_object(name, backend_name): ...@@ -115,6 +115,7 @@ def create_dummy_object(name, backend_name):
"ForTokenClassification", "ForTokenClassification",
"Model", "Model",
"Tokenizer", "Tokenizer",
"Processor",
] ]
if name.isupper(): if name.isupper():
return DUMMY_CONSTANT.format(name) return DUMMY_CONSTANT.format(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment