Unverified Commit 129cb6d5 authored by Yih-Dar, committed by GitHub

Prevent some pipeline tasks from using `use_cache=True` (#24893)



* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 476be08c
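The three hunks below add the same guard to QuestionAnsweringPipeline, TextClassificationPipeline, and ZeroShotClassificationPipeline: probe the model's forward signature and, if it accepts `use_cache`, force it off, since classification-style heads only need the final logits and gain nothing from allocating past key/value caches. A minimal standalone sketch of the pattern; the helper name `maybe_disable_cache` is illustrative, not part of the commit:

import inspect


def maybe_disable_cache(model, framework, model_inputs):
    # Pick the callable whose signature declares the model's accepted kwargs:
    # `forward` for PyTorch models, `call` for TensorFlow/Keras models.
    model_forward = model.forward if framework == "pt" else model.call
    # Only inject the flag when the signature actually accepts it, so
    # models without a cache are left untouched.
    if "use_cache" in inspect.signature(model_forward).parameters:
        model_inputs["use_cache"] = False
    return model_inputs

Note that `inspect.signature(...).parameters` is a mapping, so the membership test works on it directly; the `.keys()` call in the diff is harmless but redundant.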
src/transformers/pipelines/question_answering.py
import inspect
import types
import warnings
from collections.abc import Iterable
@@ -510,6 +511,10 @@ class QuestionAnsweringPipeline(ChunkPipeline):
    def _forward(self, inputs):
        example = inputs["example"]
        model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
        # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
        model_forward = self.model.forward if self.framework == "pt" else self.model.call
        if "use_cache" in inspect.signature(model_forward).parameters.keys():
            model_inputs["use_cache"] = False
        output = self.model(**model_inputs)
        if isinstance(output, dict):
            return {"start": output["start_logits"], "end": output["end_logits"], "example": example, **inputs}
...
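For context (this is not part of the diff): decoder-based classification models such as GPT2ForSequenceClassification do accept `use_cache`, and with it left on they return past key/values the pipeline never reads. A quick check, assuming the `gpt2` checkpoint can be downloaded; its classification head is newly initialized, which is fine for inspecting outputs:

import inspect

from transformers import AutoTokenizer, GPT2ForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2ForSequenceClassification.from_pretrained("gpt2")

# The signature probe from the diff fires for this model.
print("use_cache" in inspect.signature(model.forward).parameters)  # True

inputs = tokenizer("This movie was great", return_tensors="pt")
print(model(**inputs, use_cache=True).past_key_values is not None)  # True: cache allocated
print(model(**inputs, use_cache=False).past_key_values)             # None: nothing cached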
src/transformers/pipelines/text_classification.py
import inspect
import warnings
from typing import Dict
@@ -179,6 +180,10 @@ class TextClassificationPipeline(Pipeline):
        return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)

    def _forward(self, model_inputs):
        # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
        model_forward = self.model.forward if self.framework == "pt" else self.model.call
        if "use_cache" in inspect.signature(model_forward).parameters.keys():
            model_inputs["use_cache"] = False
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
...
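The guard is deliberately a no-op for encoder-only models, whose forward signature has no `use_cache` parameter. A toy illustration with two hypothetical stand-ins (not real transformers classes):

import inspect


class EncoderStyleModel:
    def forward(self, input_ids, attention_mask=None):  # no `use_cache`
        ...


class DecoderStyleModel:
    def forward(self, input_ids, attention_mask=None, use_cache=None):
        ...


for model in (EncoderStyleModel(), DecoderStyleModel()):
    accepts = "use_cache" in inspect.signature(model.forward).parameters
    print(type(model).__name__, accepts)
# EncoderStyleModel False  -> model_inputs pass through unchanged
# DecoderStyleModel True   -> use_cache=False is injected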
src/transformers/pipelines/zero_shot_classification.py
import inspect
from typing import List, Union

import numpy as np
@@ -221,6 +222,10 @@ class ZeroShotClassificationPipeline(ChunkPipeline):
        candidate_label = inputs["candidate_label"]
        sequence = inputs["sequence"]
        model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
        # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
        model_forward = self.model.forward if self.framework == "pt" else self.model.call
        if "use_cache" in inspect.signature(model_forward).parameters.keys():
            model_inputs["use_cache"] = False
        outputs = self.model(**model_inputs)
        model_outputs = {
...
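Externally, the affected pipelines behave as before; the change only stops them from requesting caches they never consume. A smoke test, assuming the default checkpoint for each task can be downloaded:

from transformers import pipeline

classifier = pipeline("text-classification")
print(classifier("This movie was great!"))

zero_shot = pipeline("zero-shot-classification")
print(zero_shot("I want to book a flight", candidate_labels=["travel", "cooking"]))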