"...grub-2.04/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "a1c6fe2d2428cb8a1b4a9e11a9a5075a4bccecd1"
Commit be5bf7b8 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Added NER pipeline.

parent 80eacb8f
......@@ -16,18 +16,20 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
from abc import ABC, abstractmethod
from itertools import groupby
from typing import Union, Optional, Tuple, List, Dict
import numpy as np
from transformers import is_tf_available, is_torch_available, logger, AutoTokenizer, PreTrainedTokenizer, \
SquadExample, squad_convert_examples_to_features
from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger
if is_tf_available():
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
if is_torch_available():
from transformers import AutoModelForSequenceClassification, AutoModelForQuestionAnswering
import torch
from transformers import AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForTokenClassification
class Pipeline(ABC):
......@@ -95,9 +97,57 @@ class TextClassificationPipeline(Pipeline):
return predictions.numpy().tolist()
class NerPipeline(Pipeline):
def __init__(self, model, tokenizer: PreTrainedTokenizer):
super().__init__(model, tokenizer)
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass
def __call__(self, *texts, **kwargs):
(texts, ), answers = texts, []
for sentence in texts:
# Ugly token to word idx mapping (for now)
token_to_word, words = [], sentence.split(' ')
for i, w in enumerate(words):
tokens = self.tokenizer.tokenize(w)
token_to_word += [i] * len(tokens)
tokens = self.tokenizer.encode_plus(sentence, return_attention_mask=False, return_tensors='tf' if is_tf_available() else 'pt')
# Forward
if is_torch_available():
with torch.no_grad():
entities = self.model(**tokens)[0][0].cpu().numpy()
else:
entities = self.model(tokens)[0][0].numpy()
# Normalize scores
answer, token_start = [], 1
for idx, word in groupby(token_to_word[1:-1]):
# Sum log prob over token, then normalize across labels
score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
label_idx = score.argmax()
answer += [{
'word': words[idx - 1], 'score': score[label_idx], 'entity': self.model.config.id2label[label_idx]
}]
# Update token start
token_start += len(list(word))
# Append
answers += [answer]
return answers
class QuestionAnsweringPipeline(Pipeline):
"""
Question Answering pipeling involving Tokenization and Inference.
Question Answering pipeline involving Tokenization and Inference.
"""
@classmethod
......@@ -219,6 +269,7 @@ class QuestionAnsweringPipeline(Pipeline):
# Mask padding and question
start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
# TODO : What happend if not possible
# Mask CLS
start_[0] = end_[0] = 0
......@@ -301,6 +352,11 @@ SUPPORTED_TASKS = {
'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
'pt': AutoModelForSequenceClassification if is_torch_available() else None
},
'ner': {
'impl': NerPipeline,
'tf': TFAutoModelForTokenClassification if is_tf_available() else None,
'pt': AutoModelForTokenClassification if is_torch_available() else None,
},
'question-answering': {
'impl': QuestionAnsweringPipeline,
'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
......@@ -309,7 +365,7 @@ SUPPORTED_TASKS = {
}
def pipeline(task: str, model, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
"""
Utility factory method to build pipeline.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment