Unverified Commit 4ba66fdb authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Fix pipeline tests - torch imports (#31227)

* Fix pipeline tests - torch imports

* Frameowrk dependant float conversion
parent 6b22a8f2
......@@ -202,7 +202,12 @@ class TextClassificationPipeline(Pipeline):
function_to_apply = ClassificationFunction.NONE
outputs = model_outputs["logits"][0]
outputs = outputs.float().numpy()
if self.framework == "pt":
# To enable using fp16 and bf16
outputs = outputs.float().numpy()
else:
outputs = outputs.numpy()
if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)
......
......@@ -14,8 +14,6 @@
import unittest
import torch
from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
......@@ -24,6 +22,7 @@ from transformers import (
)
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
......@@ -36,6 +35,10 @@ from transformers.testing_utils import (
from .test_pipelines_common import ANY
if is_torch_available():
import torch
# These 2 model types require different inputs than those of the usual text models.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment