Commit a2410110 authored by thomwolf

fix pipeline NER

parent e37ca8e1
@@ -463,7 +463,7 @@ class NerPipeline(Pipeline):
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
-                 binary_output: bool = False):
+                 binary_output: bool = False, ignore_labels=['O']):
         super().__init__(model=model,
                          tokenizer=tokenizer,
                          modelcard=modelcard,
@@ -473,17 +473,12 @@ class NerPipeline(Pipeline):
                          binary_output=binary_output)
 
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
+        self.ignore_labels = ignore_labels
 
     def __call__(self, *texts, **kwargs):
         inputs, answers = self._args_parser(*texts, **kwargs), []
         for sentence in inputs:
-            # Ugly token to word idx mapping (for now)
-            token_to_word, words = [], self._basic_tokenizer.tokenize(sentence)
-            for i, w in enumerate(words):
-                tokens = self.tokenizer.tokenize(w)
-                token_to_word += [i] * len(tokens)
 
             # Manage correct placement of the tensors
             with self.device_placement():
@@ -500,26 +495,22 @@ class NerPipeline(Pipeline):
                 with torch.no_grad():
                     entities = self.model(**tokens)[0][0].cpu().numpy()
 
-            # Normalize scores
-            answer, token_start = [], 1
-            for idx, word in groupby(token_to_word):
-                # Sum log prob over token, then normalize across labels
-                score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
-                label_idx = score.argmax()
-
-                if label_idx > 0:
-                    answer += [{
-                        'word': words[idx],
-                        'score': score[label_idx].item(),
-                        'entity': self.model.config.id2label[label_idx]
-                    }]
-
-                # Update token start
-                token_start += len(list(word))
+            score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
+            labels_idx = score.argmax(axis=-1)
+
+            answer = []
+            for idx, label_idx in enumerate(labels_idx):
+                if self.model.config.id2label[label_idx] not in self.ignore_labels:
+                    answer += [{
+                        'word': self.tokenizer.decode(tokens['input_ids'][0][idx].cpu().tolist()),
+                        'score': score[idx][label_idx].item(),
+                        'entity': self.model.config.id2label[label_idx]
+                    }]
 
             # Append
             answers += [answer]
+        if len(answers) == 1:
+            return answers[0]
         return answers
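
For illustration only (not part of the commit), here is a minimal standalone sketch of the new scoring path: softmax the token-classification logits over the label axis for all tokens at once, take the argmax label per token, and keep only tokens whose label is not in ignore_labels. The logits and label map below are made-up stand-ins for self.model(**tokens)[0][0] and self.model.config.id2label.

import numpy as np

# Made-up logits for a 4-token input over a 3-label scheme; in the pipeline these
# come from entities = self.model(**tokens)[0][0].cpu().numpy()
entities = np.array([
    [4.0, 0.1, 0.2],   # -> 'O'
    [0.3, 3.5, 0.4],   # -> 'B-PER'
    [0.2, 0.3, 3.8],   # -> 'I-PER'
    [3.9, 0.2, 0.1],   # -> 'O'
])
id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER'}   # stand-in for self.model.config.id2label
ignore_labels = ['O']                          # the new constructor default

# Normalize over the label dimension for every token at once, then pick the best
# label per token (replaces the old per-word loop over groupby(token_to_word))
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1)

answer = [
    {'score': score[idx][label_idx].item(), 'entity': id2label[label_idx]}
    for idx, label_idx in enumerate(labels_idx)
    if id2label[label_idx] not in ignore_labels
]
print(answer)  # two entries here, one per non-'O' token

Net effect of the change: the pipeline emits one dict per sub-token whose predicted label is not in ignore_labels (default ['O']), each carrying 'word', 'score' and 'entity', and a single input sentence now returns that list directly instead of a one-element outer list.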