Commit f79a7dc6 authored by thomwolf's avatar thomwolf
Browse files

fix NER pipeline

parent a2410110
...@@ -491,9 +491,11 @@ class NerPipeline(Pipeline): ...@@ -491,9 +491,11 @@ class NerPipeline(Pipeline):
# Forward # Forward
if self.framework == 'tf': if self.framework == 'tf':
entities = self.model(tokens)[0][0].numpy() entities = self.model(tokens)[0][0].numpy()
input_ids = tokens['input_ids'].numpy()[0]
else: else:
with torch.no_grad(): with torch.no_grad():
entities = self.model(**tokens)[0][0].cpu().numpy() entities = self.model(**tokens)[0][0].cpu().numpy()
input_ids = tokens['input_ids'].cpu().numpy()[0]
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True) score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1) labels_idx = score.argmax(axis=-1)
...@@ -502,7 +504,7 @@ class NerPipeline(Pipeline): ...@@ -502,7 +504,7 @@ class NerPipeline(Pipeline):
for idx, label_idx in enumerate(labels_idx): for idx, label_idx in enumerate(labels_idx):
if self.model.config.id2label[label_idx] not in self.ignore_labels: if self.model.config.id2label[label_idx] not in self.ignore_labels:
answer += [{ answer += [{
'word': self.tokenizer.decode(tokens['input_ids'][0][idx].cpu().tolist()), 'word': self.tokenizer.decode(int(input_ids[idx])),
'score': score[idx][label_idx].item(), 'score': score[idx][label_idx].item(),
'entity': self.model.config.id2label[label_idx] 'entity': self.model.config.id2label[label_idx]
}] }]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment