"vscode:/vscode.git/clone" did not exist on "6820904454d108961a7bb1c99b2065b75d94bf01"
Unverified Commit e7b16f33 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing GPU for token-classification in a better way. (#13856)


Co-authored-by: default avatarPierre Snell <pierre.snell@botpress.com>
Co-authored-by: default avatarPierre Snell <pierre.snell@botpress.com>
parent 7d83655d
......@@ -791,7 +791,7 @@ class Pipeline(_ScikitCompat):
elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor):
return inputs.to(self.device)
return inputs.to(device)
else:
return inputs
......
......@@ -204,9 +204,10 @@ class TokenClassificationPipeline(Pipeline):
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
if self.framework == "tf":
outputs = self.model(model_inputs.data)[0][0].numpy()
outputs = self.model(model_inputs.data)[0][0]
else:
outputs = self.model(**model_inputs)[0][0].numpy()
outputs = self.model(**model_inputs)[0][0]
return {
"outputs": outputs,
"special_tokens_mask": special_tokens_mask,
......@@ -216,7 +217,7 @@ class TokenClassificationPipeline(Pipeline):
}
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
outputs = model_outputs["outputs"]
outputs = model_outputs["outputs"].numpy()
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
......
......@@ -25,7 +25,14 @@ from transformers import (
pipeline,
)
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
......@@ -246,6 +253,19 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
@require_torch_gpu
@slow
def test_gpu(self):
sentence = "This is dummy sentence"
ner = pipeline(
"token-classification",
device=0,
aggregation_strategy=AggregationStrategy.SIMPLE,
)
output = ner(sentence)
self.assertEqual(nested_simplify(output), [])
@require_torch
@slow
def test_dbmdz_english(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment