chenpangpang / transformers · Commits

Commit bcc99fd9, authored Dec 19, 2019 by Morgan Funtowicz
Parent: ec5d6c6a

    Fix wrong automatic config allocation through AutoConfig

Showing 1 changed file with 78 additions and 39 deletions:

    transformers/pipelines.py (+78 −39)
@@ -25,7 +25,7 @@ from typing import Union, Optional, Tuple, List, Dict
 import numpy as np

-from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
     SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger

 if is_tf_available():
@@ -264,6 +264,27 @@ class Pipeline(_ScikitCompat):
             yield

+    def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
+        """
+        Generates the input dictionary with model-specific parameters.
+
+        Returns:
+            dict holding all the required parameters for model's forward
+        """
+        args = ['input_ids', 'attention_mask']
+        model_type = type(self.model).__name__.lower()
+
+        if 'distilbert' not in model_type and 'xlm' not in model_type:
+            args += ['token_type_ids']
+
+        if 'xlnet' in model_type or 'xlm' in model_type:
+            args += ['cls_index', 'p_mask']
+
+        if isinstance(features, dict):
+            return {k: features[k] for k in args}
+        else:
+            return {k: [feature[k] for feature in features] for k in args}
+
     def __call__(self, *texts, **kwargs):
         # Parse arguments
         inputs = self._args_parser(*texts, **kwargs)
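The new Pipeline.inputs_for_model keeps only the encoded features a given architecture accepts: DistilBERT takes no token_type_ids, while XLNet/XLM additionally expect cls_index and p_mask. A minimal standalone sketch of the same filtering logic, using a class name string in place of a real model (the wrapper function and its name are illustrative, not from the commit):

    from typing import Dict, List, Union

    def inputs_for_model_name(model_name: str, features: Union[dict, List[dict]]) -> Dict:
        # Mirror of the selection logic added to Pipeline.inputs_for_model.
        args = ['input_ids', 'attention_mask']
        model_type = model_name.lower()

        if 'distilbert' not in model_type and 'xlm' not in model_type:
            args += ['token_type_ids']  # BERT-like models use segment ids

        if 'xlnet' in model_type or 'xlm' in model_type:
            args += ['cls_index', 'p_mask']

        if isinstance(features, dict):
            return {k: features[k] for k in args}
        return {k: [f[k] for f in features] for k in args}

    features = {'input_ids': [1, 2], 'attention_mask': [1, 1], 'token_type_ids': [0, 0]}
    print(inputs_for_model_name('DistilBertModel', features))  # drops token_type_ids
    print(inputs_for_model_name('BertModel', features))        # keeps all three keys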
@@ -271,9 +292,14 @@ class Pipeline(_ScikitCompat):
         # Encode for forward
         with self.device_placement():
             inputs = self.tokenizer.batch_encode_plus(
                 inputs, add_special_tokens=True,
-                return_tensors='tf' if is_tf_available() else 'pt'
+                return_tensors='tf' if is_tf_available() else 'pt',
+                # max_length=self.model.config.max_position_embedding
+                max_length=511
             )

+            # Filter out features not available on specific models
+            inputs = self.inputs_for_model(inputs)
+
             return self._forward(inputs)

     def _forward(self, inputs):
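The hard-coded max_length=511 caps encoded sequences just under the 512-position limit of BERT-style models; the commented-out config lookup names the attribute max_position_embedding, while the actual field on these configs is max_position_embeddings, which may be why it was left disabled. A hedged sketch of the truncation effect, assuming a transformers release contemporary with this commit (2.x), where passing max_length to batch_encode_plus truncates by default; the checkpoint name is only illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # illustrative checkpoint

    encoded = tokenizer.batch_encode_plus(
        ['long input ' * 1000, 'short input'],
        add_special_tokens=True,
        max_length=511,  # same hard-coded cap as in the hunk above
    )
    # No encoded sequence exceeds 511 tokens after truncation.
    print([len(ids) for ids in encoded['input_ids']])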
@@ -331,7 +357,11 @@ class NerPipeline(Pipeline):
             # Manage correct placement of the tensors
             with self.device_placement():
-                tokens = self.tokenizer.encode_plus(sentence, return_attention_mask=False,
-                                                    return_tensors='tf' if is_tf_available() else 'pt')
+                tokens = self.tokenizer.encode_plus(
+                    sentence, return_attention_mask=False,
+                    return_tensors='tf' if is_tf_available() else 'pt',
+                    max_length=512
+                )

                 # Forward
                 if is_torch_available():
@@ -443,27 +473,6 @@ class QuestionAnsweringPipeline(Pipeline):
         super().__init__(model, tokenizer, args_parser=QuestionAnsweringArgumentHandler(),
                          device=device, **kwargs)

-    def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
-        """
-        Generates the input dictionary with model-specific parameters.
-
-        Returns:
-            dict holding all the required parameters for model's forward
-        """
-        args = ['input_ids', 'attention_mask']
-        model_type = type(self.model).__name__.lower()
-
-        if 'distilbert' not in model_type and 'xlm' not in model_type:
-            args += ['token_type_ids']
-
-        if 'xlnet' in model_type or 'xlm' in model_type:
-            args += ['cls_index', 'p_mask']
-
-        if isinstance(features, SquadExample):
-            return {k: features.__dict__[k] for k in args}
-        else:
-            return {k: [feature.__dict__[k] for feature in features] for k in args}
-
     def __call__(self, *texts, **kwargs):
         """
         Args:
@@ -495,7 +504,7 @@ class QuestionAnsweringPipeline(Pipeline):
         # Convert inputs to features
         examples = self._args_parser(*texts, **kwargs)
         features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
-        fw_args = self.inputs_for_model(features)
+        fw_args = self.inputs_for_model(features.__dict__)

         # Manage tensor allocation on correct device
         with self.device_placement():
@@ -630,26 +639,47 @@ SUPPORTED_TASKS = {
         'impl': FeatureExtractionPipeline,
         'tf': TFAutoModel if is_tf_available() else None,
         'pt': AutoModel if is_torch_available() else None,
+        'default': {'model': 'distilbert-base-uncased', 'config': None, 'tokenizer': 'bert-base-uncased'}
     },
-    'text-classification': {
+    'sentiment-analysis': {
         'impl': TextClassificationPipeline,
         'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
-        'pt': AutoModelForSequenceClassification if is_torch_available() else None
+        'pt': AutoModelForSequenceClassification if is_torch_available() else None,
+        'default': {'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
+                    'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json',
+                    'tokenizer': 'bert-base-uncased'}
     },
     'ner': {
         'impl': NerPipeline,
         'tf': TFAutoModelForTokenClassification if is_tf_available() else None,
         'pt': AutoModelForTokenClassification if is_torch_available() else None,
+        'default': {'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
+                    'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json',
+                    'tokenizer': 'bert-base-cased'}
     },
     'question-answering': {
         'impl': QuestionAnsweringPipeline,
         'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
-        'pt': AutoModelForQuestionAnswering if is_torch_available() else None
+        'pt': AutoModelForQuestionAnswering if is_torch_available() else None,
+        'default': {'model': 'distilbert-base-uncased-distilled-squad', 'config': None, 'tokenizer': 'bert-base-uncased'}
     }
 }

-def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] = None,
+def pipeline(task: str, model: Optional = None, config: Optional[Union[str, PretrainedConfig]] = None,
              tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
     """
     Utility factory method to build a pipeline.
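Each task entry now carries a 'default' dict, and pipeline() (next hunk) unpacks it with tuple(targeted_task['default'].values()) into model, config, tokenizer. A small sketch of that resolution step; note it relies on the insertion order of the 'default' keys, which plain dicts only guarantee on Python 3.7+ (the stand-in entry below is illustrative):

    # Illustrative stand-in for one SUPPORTED_TASKS entry.
    targeted_task = {
        'default': {
            'model': 'distilbert-base-uncased-distilled-squad',
            'config': None,
            'tokenizer': 'bert-base-uncased',
        }
    }

    # Same unpacking as in pipeline(): order must be model, config, tokenizer.
    model, config, tokenizer = tuple(targeted_task['default'].values())
    print(model, config, tokenizer)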
@@ -657,23 +687,32 @@ def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] =
             A Tokenizer instance in charge of mapping raw textual input to token
             A Model instance
             Some (optional) post processing for enhancing model's output
+
+    Examples:
+        pipeline('ner')
     """
     # Try to infer tokenizer from model name (if provided as str)
     if tokenizer is None:
-        if not isinstance(model, str):
+        if model is not None and not isinstance(model, str):
             # Impossible to guest what is the right tokenizer here
             raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
         else:
             tokenizer = model

-    tokenizer = tokenizer if isinstance(tokenizer, PreTrainedTokenizer) else AutoTokenizer.from_pretrained(tokenizer)
-
+    # Retrieve the task
     if task not in SUPPORTED_TASKS:
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

     targeted_task = SUPPORTED_TASKS[task]
     task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']

+    # Handling for default model for the task
+    if model is None:
+        model, config, tokenizer = tuple(targeted_task['default'].values())
+
+    # Allocate tokenizer
+    tokenizer = tokenizer if isinstance(tokenizer, PreTrainedTokenizer) else AutoTokenizer.from_pretrained(tokenizer)
+
     # Special handling for model conversion
     if isinstance(model, str):
         from_tf = model.endswith('.h5') and not is_tf_available()
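With model now optional, the factory can be called with a task name alone, as in the docstring's new Examples entry. A usage sketch under that assumption (this downloads the task's default checkpoint on first call; the exact output format is the pipeline's token-level entity predictions):

    from transformers import pipeline

    # model/config/tokenizer fall back to SUPPORTED_TASKS['ner']['default'].
    nlp = pipeline('ner')
    print(nlp('My name is Wolfgang and I live in Berlin'))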
@@ -689,7 +728,7 @@ def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] =
         from_tf = from_pt = False

     if isinstance(config, str):
-        config = PretrainedConfig.from_pretrained(config)
+        config = AutoConfig.from_pretrained(config)

     if allocator.__name__.startswith('TF'):
         model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
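This last hunk is the fix named in the commit title. PretrainedConfig.from_pretrained instantiates the generic base class, while AutoConfig.from_pretrained dispatches on the checkpoint and returns the matching subclass; since the AutoModel-style allocators pick the architecture by inspecting the config's concrete class, a generic config routed model allocation wrongly. A quick way to observe the difference (checkpoint name illustrative):

    from transformers import AutoConfig

    # AutoConfig resolves the architecture-specific class for the checkpoint,
    # e.g. DistilBertConfig here, rather than a bare PretrainedConfig.
    config = AutoConfig.from_pretrained('distilbert-base-uncased')
    print(type(config).__name__)  # DistilBertConfig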