"...resnet50_tensorflow.git" did not exist on "4d6820e0b0526521840038a6db328401f4fa96b3"
Commit 8938b546 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Removed from_config

parent 1ca52567
......@@ -37,11 +37,6 @@ class Pipeline(ABC):
self.model = model
self.tokenizer = tokenizer
@classmethod
@abstractmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
raise NotImplementedError()
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
......@@ -63,6 +58,12 @@ class Pipeline(ABC):
raise NotImplementedError()
class FeatureExtractionPipeline(Pipeline):
def __call__(self, *texts, **kwargs):
pass
class TextClassificationPipeline(Pipeline):
def __init__(self, model, tokenizer: PreTrainedTokenizer, nb_classes: int = 2):
super().__init__(model, tokenizer)
......@@ -71,10 +72,6 @@ class TextClassificationPipeline(Pipeline):
raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes))
self._nb_classes = nb_classes
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
return cls(model, tokenizer, **kwargs)
def __call__(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
if 'X' in kwargs and not texts:
......@@ -102,10 +99,6 @@ class NerPipeline(Pipeline):
def __init__(self, model, tokenizer: PreTrainedTokenizer):
super().__init__(model, tokenizer)
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass
def __call__(self, *texts, **kwargs):
(texts, ), answers = texts, []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment