chenpangpang/transformers, commit bcc99fd9

Fix wrong automatic config allocation through AutoConfig

Authored Dec 19, 2019 by Morgan Funtowicz
Parent: ec5d6c6a
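
The substance of the fix is a single call in the pipeline() factory: a configuration passed as a string is now materialized through AutoConfig instead of the PretrainedConfig base class (last hunk below). A hedged sketch of why that matters; the model name is illustrative and the behavior is that of the transformers releases contemporary with this commit:

    from transformers import AutoConfig, PretrainedConfig

    # AutoConfig resolves the architecture-specific configuration class from
    # the checkpoint identifier, so the right subclass comes back.
    config = AutoConfig.from_pretrained('distilbert-base-uncased')
    print(type(config).__name__)   # DistilBertConfig

    # The old code used the base class directly, which loads the same JSON but
    # returns a generic PretrainedConfig, not the architecture's own class.
    generic = PretrainedConfig.from_pretrained('distilbert-base-uncased')
    print(type(generic).__name__)  # PretrainedConfig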
Showing 1 changed file with 78 additions and 39 deletions.

transformers/pipelines.py (+78, -39)
@@ -25,7 +25,7 @@ from typing import Union, Optional, Tuple, List, Dict
 import numpy as np
 
-from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
     SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger
 
 if is_tf_available():
@@ -264,6 +264,27 @@ class Pipeline(_ScikitCompat):
         yield
 
+    def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
+        """
+        Generates the input dictionary with model-specific parameters.
+
+        Returns:
+            dict holding all the required parameters for model's forward
+        """
+        args = ['input_ids', 'attention_mask']
+        model_type = type(self.model).__name__.lower()
+
+        if 'distilbert' not in model_type and 'xlm' not in model_type:
+            args += ['token_type_ids']
+
+        if 'xlnet' in model_type or 'xlm' in model_type:
+            args += ['cls_index', 'p_mask']
+
+        if isinstance(features, dict):
+            return {k: features[k] for k in args}
+        else:
+            return {k: [feature[k] for feature in features] for k in args}
+
     def __call__(self, *texts, **kwargs):
         # Parse arguments
         inputs = self._args_parser(*texts, **kwargs)
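
An aside on the helper added above: it prunes the encoded tokenizer output down to the keyword arguments the wrapped model actually accepts (DistilBERT takes no token_type_ids; XLNet/XLM additionally want cls_index and p_mask). A self-contained sketch of the filtering, with fabricated values:

    # Fabricated tokenizer output for one sequence.
    features = {'input_ids': [101, 7592, 102],
                'attention_mask': [1, 1, 1],
                'token_type_ids': [0, 0, 0]}
    model_type = 'distilbertmodel'   # stands in for type(self.model).__name__.lower()

    args = ['input_ids', 'attention_mask']
    if 'distilbert' not in model_type and 'xlm' not in model_type:
        args += ['token_type_ids']
    if 'xlnet' in model_type or 'xlm' in model_type:
        args += ['cls_index', 'p_mask']

    # token_type_ids is dropped for DistilBERT:
    print({k: features[k] for k in args})
    # {'input_ids': [101, 7592, 102], 'attention_mask': [1, 1, 1]}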
@@ -271,9 +292,14 @@ class Pipeline(_ScikitCompat):
         # Encode for forward
         with self.device_placement():
             inputs = self.tokenizer.batch_encode_plus(
-                inputs, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
+                inputs, add_special_tokens=True,
+                return_tensors='tf' if is_tf_available() else 'pt',
+                # max_length=self.model.config.max_position_embedding
+                max_length=511
             )
+
+            # Filter out features not available on specific models
+            inputs = self.inputs_for_model(inputs)
+
             return self._forward(inputs)
 
     def _forward(self, inputs):
@@ -331,7 +357,11 @@ class NerPipeline(Pipeline):
             # Manage correct placement of the tensors
             with self.device_placement():
-                tokens = self.tokenizer.encode_plus(sentence, return_attention_mask=False, return_tensors='tf' if is_tf_available() else 'pt')
+                tokens = self.tokenizer.encode_plus(
+                    sentence, return_attention_mask=False,
+                    return_tensors='tf' if is_tf_available() else 'pt',
+                    max_length=512
+                )
 
                 # Forward
                 if is_torch_available():
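
The two encoding changes above (max_length=511 in Pipeline.__call__, max_length=512 in NerPipeline) bound the encoded sequence length; the commented-out line suggests the cap was meant to come from the model config eventually (the attribute is spelled max_position_embeddings in the config classes). A toy sketch of the effect, not the tokenizer's actual implementation:

    # BERT-style models learn a fixed number of position embeddings (512),
    # so longer encoded inputs must be truncated before the forward pass.
    token_ids = list(range(700))        # stand-in for an over-long encoding
    max_length = 512
    truncated = token_ids[:max_length]  # what max_length= asks encode_plus to do
    assert len(truncated) == 512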
@@ -443,27 +473,6 @@ class QuestionAnsweringPipeline(Pipeline):
         super().__init__(model, tokenizer, args_parser=QuestionAnsweringArgumentHandler(),
                          device=device, **kwargs)
 
-    def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
-        """
-        Generates the input dictionary with model-specific parameters.
-
-        Returns:
-            dict holding all the required parameters for model's forward
-        """
-        args = ['input_ids', 'attention_mask']
-        model_type = type(self.model).__name__.lower()
-
-        if 'distilbert' not in model_type and 'xlm' not in model_type:
-            args += ['token_type_ids']
-
-        if 'xlnet' in model_type or 'xlm' in model_type:
-            args += ['cls_index', 'p_mask']
-
-        if isinstance(features, SquadExample):
-            return {k: features.__dict__[k] for k in args}
-        else:
-            return {k: [feature.__dict__[k] for feature in features] for k in args}
-
     def __call__(self, *texts, **kwargs):
         """
         Args:
@@ -495,7 +504,7 @@ class QuestionAnsweringPipeline(Pipeline):
         # Convert inputs to features
         examples = self._args_parser(*texts, **kwargs)
         features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
-        fw_args = self.inputs_for_model(features)
+        fw_args = self.inputs_for_model(features.__dict__)
 
         # Manage tensor allocation on correct device
         with self.device_placement():
@@ -627,29 +636,50 @@ class QuestionAnsweringPipeline(Pipeline):
 # Register all the supported task here
 SUPPORTED_TASKS = {
     'feature-extraction': {
         'impl': FeatureExtractionPipeline,
         'tf': TFAutoModel if is_tf_available() else None,
         'pt': AutoModel if is_torch_available() else None,
+        'default': {'model': 'distilbert-base-uncased', 'config': None, 'tokenizer': 'bert-base-uncased'}
     },
-    'text-classification': {
+    'sentiment-analysis': {
         'impl': TextClassificationPipeline,
         'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
-        'pt': AutoModelForSequenceClassification if is_torch_available() else None
+        'pt': AutoModelForSequenceClassification if is_torch_available() else None,
+        'default': {
+            'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
+            'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json',
+            'tokenizer': 'bert-base-uncased'
+        }
     },
     'ner': {
         'impl': NerPipeline,
         'tf': TFAutoModelForTokenClassification if is_tf_available() else None,
         'pt': AutoModelForTokenClassification if is_torch_available() else None,
+        'default': {
+            'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
+            'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json',
+            'tokenizer': 'bert-base-cased'
+        }
     },
     'question-answering': {
         'impl': QuestionAnsweringPipeline,
         'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
-        'pt': AutoModelForQuestionAnswering if is_torch_available() else None
+        'pt': AutoModelForQuestionAnswering if is_torch_available() else None,
+        'default': {'model': 'distilbert-base-uncased-distilled-squad', 'config': None, 'tokenizer': 'bert-base-uncased'}
     }
 }
 
 
-def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] = None,
+def pipeline(task: str, model: Optional = None, config: Optional[Union[str, PretrainedConfig]] = None,
              tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
     """
     Utility factory method to build a pipeline.
@@ -657,23 +687,32 @@ def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] =
         A Tokenizer instance in charge of mapping raw textual input to token
         A Model instance
         Some (optional) post processing for enhancing model's output
+
+    Examples:
+        pipeline('ner')
     """
     # Try to infer tokenizer from model name (if provided as str)
     if tokenizer is None:
-        if not isinstance(model, str):
+        if model is not None and not isinstance(model, str):
             # Impossible to guest what is the right tokenizer here
             raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
         else:
             tokenizer = model
 
-    tokenizer = tokenizer if isinstance(tokenizer, PreTrainedTokenizer) else AutoTokenizer.from_pretrained(tokenizer)
-
+    # Retrieve the task
     if task not in SUPPORTED_TASKS:
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
 
     targeted_task = SUPPORTED_TASKS[task]
     task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
 
+    # Handling for default model for the task
+    if model is None:
+        model, config, tokenizer = tuple(targeted_task['default'].values())
+
+    # Allocate tokenizer
+    tokenizer = tokenizer if isinstance(tokenizer, PreTrainedTokenizer) else AutoTokenizer.from_pretrained(tokenizer)
+
     # Special handling for model conversion
     if isinstance(model, str):
         from_tf = model.endswith('.h5') and not is_tf_available()
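
The 'default' entries and the relaxed model: Optional = None signature are what let a bare pipeline('ner') work: with no model given, the factory unpacks the task's default triple. Note that tuple(targeted_task['default'].values()) leans on dict insertion order ('model', 'config', 'tokenizer'), which Python 3.7+ guarantees. A trimmed, offline sketch of that resolution step:

    # Trimmed copy of the registry above; no downloads, no real models.
    SUPPORTED_TASKS = {
        'feature-extraction': {
            'default': {'model': 'distilbert-base-uncased',
                        'config': None,
                        'tokenizer': 'bert-base-uncased'}
        }
    }

    model = None   # as in pipeline('feature-extraction') with no model argument
    targeted_task = SUPPORTED_TASKS['feature-extraction']
    if model is None:
        model, config, tokenizer = tuple(targeted_task['default'].values())

    print(model, config, tokenizer)
    # distilbert-base-uncased None bert-base-uncased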
@@ -689,7 +728,7 @@ def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] =
         from_tf = from_pt = False
 
     if isinstance(config, str):
-        config = PretrainedConfig.from_pretrained(config)
+        config = AutoConfig.from_pretrained(config)
 
     if allocator.__name__.startswith('TF'):
         model = allocator.from_pretrained(model, config=config, from_pt=from_pt)