Unverified Commit e7b16f33 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing GPU for token-classification in a better way. (#13856)


Co-authored-by: default avatarPierre Snell <pierre.snell@botpress.com>
Co-authored-by: default avatarPierre Snell <pierre.snell@botpress.com>
parent 7d83655d
...@@ -791,7 +791,7 @@ class Pipeline(_ScikitCompat): ...@@ -791,7 +791,7 @@ class Pipeline(_ScikitCompat):
elif isinstance(inputs, tuple): elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs]) return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor): elif isinstance(inputs, torch.Tensor):
return inputs.to(self.device) return inputs.to(device)
else: else:
return inputs return inputs
......
...@@ -204,9 +204,10 @@ class TokenClassificationPipeline(Pipeline): ...@@ -204,9 +204,10 @@ class TokenClassificationPipeline(Pipeline):
offset_mapping = model_inputs.pop("offset_mapping", None) offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence") sentence = model_inputs.pop("sentence")
if self.framework == "tf": if self.framework == "tf":
outputs = self.model(model_inputs.data)[0][0].numpy() outputs = self.model(model_inputs.data)[0][0]
else: else:
outputs = self.model(**model_inputs)[0][0].numpy() outputs = self.model(**model_inputs)[0][0]
return { return {
"outputs": outputs, "outputs": outputs,
"special_tokens_mask": special_tokens_mask, "special_tokens_mask": special_tokens_mask,
...@@ -216,7 +217,7 @@ class TokenClassificationPipeline(Pipeline): ...@@ -216,7 +217,7 @@ class TokenClassificationPipeline(Pipeline):
} }
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE): def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
outputs = model_outputs["outputs"] outputs = model_outputs["outputs"].numpy()
sentence = model_outputs["sentence"] sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0] input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
......
...@@ -25,7 +25,14 @@ from transformers import ( ...@@ -25,7 +25,14 @@ from transformers import (
pipeline, pipeline,
) )
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta from .test_pipelines_common import ANY, PipelineTestCaseMeta
...@@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
], ],
) )
@require_torch_gpu
@slow
def test_gpu(self):
sentence = "This is dummy sentence"
ner = pipeline(
"token-classification",
device=0,
aggregation_strategy=AggregationStrategy.SIMPLE,
)
output = ner(sentence)
self.assertEqual(nested_simplify(output), [])
@require_torch @require_torch
@slow @slow
def test_dbmdz_english(self): def test_dbmdz_english(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment