Commit 9a50828b authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Pipelines: fix crash when modelcard is None

cc @mfuntowicz does this seem correct?
parent 6c1b2355
...@@ -326,7 +326,7 @@ class Pipeline(_ScikitCompat): ...@@ -326,7 +326,7 @@ class Pipeline(_ScikitCompat):
self, self,
model, model,
tokenizer: PreTrainedTokenizer = None, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, args_parser: ArgumentHandler = None,
device: int = -1, device: int = -1,
...@@ -358,7 +358,8 @@ class Pipeline(_ScikitCompat): ...@@ -358,7 +358,8 @@ class Pipeline(_ScikitCompat):
self.model.save_pretrained(save_directory) self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory) self.tokenizer.save_pretrained(save_directory)
self.modelcard.save_pretrained(save_directory) if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory)
def transform(self, X): def transform(self, X):
""" """
...@@ -476,7 +477,7 @@ class FeatureExtractionPipeline(Pipeline): ...@@ -476,7 +477,7 @@ class FeatureExtractionPipeline(Pipeline):
self, self,
model, model,
tokenizer: PreTrainedTokenizer = None, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, args_parser: ArgumentHandler = None,
device: int = -1, device: int = -1,
...@@ -515,7 +516,7 @@ class FillMaskPipeline(Pipeline): ...@@ -515,7 +516,7 @@ class FillMaskPipeline(Pipeline):
self, self,
model, model,
tokenizer: PreTrainedTokenizer = None, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, args_parser: ArgumentHandler = None,
device: int = -1, device: int = -1,
...@@ -582,7 +583,7 @@ class NerPipeline(Pipeline): ...@@ -582,7 +583,7 @@ class NerPipeline(Pipeline):
self, self,
model, model,
tokenizer: PreTrainedTokenizer = None, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, args_parser: ArgumentHandler = None,
device: int = -1, device: int = -1,
...@@ -721,7 +722,7 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -721,7 +722,7 @@ class QuestionAnsweringPipeline(Pipeline):
self, self,
model, model,
tokenizer: Optional[PreTrainedTokenizer], tokenizer: Optional[PreTrainedTokenizer],
modelcard: Optional[ModelCard], modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
device: int = -1, device: int = -1,
**kwargs **kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment