transformers · Commit 215e0681 (unverified)

Authored May 06, 2022 by Ritik Nandwal; committed via GitHub on May 06, 2022
Parent: 351cdbdf

Added BigBirdPegasus onnx config (#17104)

* Add onnx configuration for bigbird-pegasus
* Modify docs
Showing 5 changed files with 254 additions and 3 deletions (+254 −3):

  docs/source/en/serialization.mdx                                            +1   −0
  src/transformers/models/bigbird_pegasus/__init__.py                         +10  −2
  src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py    +230 −1
  src/transformers/onnx/features.py                                           +12  −0
  tests/onnx/test_onnx_v2.py                                                  +1   −0
docs/source/en/serialization.mdx

@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures:
 - BEiT
 - BERT
 - BigBird
+- BigBirdPegasus
 - Blenderbot
 - BlenderbotSmall
 - CamemBERT
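Listing the architecture here means the ready-made export path now covers BigBirdPegasus; with the feature mapping below it, a command along the lines of `python -m transformers.onnx --model=google/bigbird-pegasus-large-arxiv --feature=seq2seq-lm onnx/` should work (flag names as of the 4.19-era CLI). The following is a minimal sketch of the equivalent Python-level export; it assumes a transformers release containing this commit, the `onnx` extras, and PyTorch, and reuses the checkpoint name from the test file:

# Minimal export sketch (assumes transformers >= 4.19 with the `onnx` extras and PyTorch)
from pathlib import Path

from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration
from transformers.models.bigbird_pegasus import BigBirdPegasusOnnxConfig
from transformers.onnx import export

ckpt = "google/bigbird-pegasus-large-arxiv"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(ckpt)

# The config class added by this commit, bound to the seq2seq-lm task
onnx_config = BigBirdPegasusOnnxConfig(model.config, task="seq2seq-lm")

# export() traces the model with dummy inputs and writes model.onnx;
# it returns the matched input and output names for later validation
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("model.onnx")
)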
src/transformers/models/bigbird_pegasus/__init__.py

@@ -21,7 +21,11 @@ from ...utils import _LazyModule, is_torch_available
 _import_structure = {
-    "configuration_bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig"],
+    "configuration_bigbird_pegasus": [
+        "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BigBirdPegasusConfig",
+        "BigBirdPegasusOnnxConfig",
+    ],
 }

 if is_torch_available():

@@ -37,7 +41,11 @@ if is_torch_available():
 if TYPE_CHECKING:
-    from .configuration_bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig
+    from .configuration_bigbird_pegasus import (
+        BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        BigBirdPegasusConfig,
+        BigBirdPegasusOnnxConfig,
+    )

     if is_torch_available():
         from .modeling_bigbird_pegasus import (
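With the lazy-import entries in place, the new class is importable from the model subpackage without loading PyTorch (the torch import inside the config module is deferred). A quick sanity check, assuming a release containing this commit:

# Sanity check: the new symbol resolves through the model subpackage's lazy module
from transformers.models.bigbird_pegasus import BigBirdPegasusOnnxConfig
from transformers.onnx import OnnxSeq2SeqConfigWithPast

assert issubclass(BigBirdPegasusOnnxConfig, OnnxSeq2SeqConfigWithPast)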
src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py

@@ -14,8 +14,14 @@
 # limitations under the License.
 """ BigBirdPegasus model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
-from ...utils import logging
+from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx.utils import compute_effective_axis_dimension
+from ...utils import TensorType, is_torch_available, logging

 logger = logging.get_logger(__name__)

@@ -185,3 +191,226 @@ class BigBirdPegasusConfig(PretrainedConfig):
             decoder_start_token_id=decoder_start_token_id,
             **kwargs,
         )
+
+
+# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig
+class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ]
+            )
+
+            if self.use_past:
+                common_inputs["decoder_input_ids"] = {0: "batch"}
+                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+            else:
+                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+            if self.use_past:
+                self.fill_with_past_key_values_(common_inputs, direction="inputs")
+        elif self.task == "causal-lm":
+            # TODO: figure this case out.
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ]
+            )
+            if self.use_past:
+                num_encoder_layers, _ = self.num_layers
+                for i in range(num_encoder_layers):
+                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+        else:
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
+                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
+                ]
+            )
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_outputs = super().outputs
+        else:
+            common_outputs = super(OnnxConfigWithPast, self).outputs
+            if self.use_past:
+                num_encoder_layers, _ = self.num_layers
+                for i in range(num_encoder_layers):
+                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+        return common_outputs
+
+    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        # Generate decoder inputs
+        decoder_seq_length = seq_length if not self.use_past else 1
+        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+            tokenizer, batch_size, decoder_seq_length, is_pair, framework
+        )
+        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
+        common_inputs = dict(**encoder_inputs, **decoder_inputs)
+
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+            batch, encoder_seq_length = common_inputs["input_ids"].shape
+            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
+            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
+            encoder_shape = (
+                batch,
+                num_encoder_attention_heads,
+                encoder_seq_length,
+                self._config.hidden_size // num_encoder_attention_heads,
+            )
+            decoder_past_length = decoder_seq_length + 3
+            decoder_shape = (
+                batch,
+                num_decoder_attention_heads,
+                decoder_past_length,
+                self._config.hidden_size // num_decoder_attention_heads,
+            )
+
+            common_inputs["decoder_attention_mask"] = torch.cat(
+                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
+            )
+
+            common_inputs["past_key_values"] = []
+            # If the number of encoder and decoder layers are present in the model configuration, both are considered
+            num_encoder_layers, num_decoder_layers = self.num_layers
+            min_num_layers = min(num_encoder_layers, num_decoder_layers)
+            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
+            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
+
+            for _ in range(min_num_layers):
+                common_inputs["past_key_values"].append(
+                    (
+                        torch.zeros(decoder_shape),
+                        torch.zeros(decoder_shape),
+                        torch.zeros(encoder_shape),
+                        torch.zeros(encoder_shape),
+                    )
+                )
+            # TODO: test this.
+            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
+            for _ in range(min_num_layers, max_num_layers):
+                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
+        return common_inputs
+
+    def _generate_dummy_inputs_for_causal_lm(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+            batch, seqlen = common_inputs["input_ids"].shape
+            # Not using the same length for past_key_values
+            past_key_values_length = seqlen + 2
+            num_encoder_layers, _ = self.num_layers
+            num_encoder_attention_heads, _ = self.num_attention_heads
+            past_shape = (
+                batch,
+                num_encoder_attention_heads,
+                past_key_values_length,
+                self._config.hidden_size // num_encoder_attention_heads,
+            )
+
+            common_inputs["attention_mask"] = torch.cat(
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+            )
+            common_inputs["past_key_values"] = [
+                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
+            ]
+        return common_inputs
+
+    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        # Copied from OnnxConfig.generate_dummy_inputs
+        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
+        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+        )
+
+        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+        )
+
+        # Generate dummy inputs according to compute batch and sequence
+        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
+        return common_inputs
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+
+        elif self.task == "causal-lm":
+            common_inputs = self._generate_dummy_inputs_for_causal_lm(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+        else:
+            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+
+        return common_inputs
+
+    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
+        if self.task in ["default", "seq2seq-lm"]:
+            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
+        else:
+            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
+                flattened_output, name, idx, t
+            )
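Since the class is copied from BartOnnxConfig, its behavior can be inspected without downloading weights: `inputs`/`outputs` declare the dynamic axes per task, and `generate_dummy_inputs` builds the tracing inputs, including zeroed `past_key_values` tensors when `use_past=True`. A small sketch under the same version assumptions (tokenizer download and PyTorch required):

# Sketch: inspect the dynamic axes and dummy inputs of the new config
from transformers import AutoTokenizer, BigBirdPegasusConfig, TensorType
from transformers.models.bigbird_pegasus import BigBirdPegasusOnnxConfig

config = BigBirdPegasusConfig()  # randomly initialized; no checkpoint needed
onnx_config = BigBirdPegasusOnnxConfig(config, task="seq2seq-lm", use_past=True)

print(onnx_config.inputs)   # input_ids/attention_mask plus decoder_* and past_key_values.* axes
print(onnx_config.outputs)  # logits plus present.* axes

tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
dummy = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
# past_key_values is a list of tuples of tensors, hence the fallback to type()
print({k: getattr(v, "shape", type(v)) for k, v in dummy.items()})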
src/transformers/onnx/features.py

@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig
 from ..models.beit import BeitOnnxConfig
 from ..models.bert import BertOnnxConfig
 from ..models.big_bird import BigBirdOnnxConfig
+from ..models.bigbird_pegasus import BigBirdPegasusOnnxConfig
 from ..models.blenderbot import BlenderbotOnnxConfig
 from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
 from ..models.camembert import CamembertOnnxConfig

@@ -164,6 +165,17 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=BigBirdOnnxConfig,
         ),
+        "bigbird-pegasus": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            "sequence-classification",
+            "question-answering",
+            onnx_config_cls=BigBirdPegasusOnnxConfig,
+        ),
         "blenderbot": supported_features_mapping(
             "default",
             "default-with-past",
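This registration is what the exporter consults when resolving a model type and `--feature` flag. A sketch of the lookup, assuming the FeaturesManager API of the same release:

# Sketch: resolve the features registered for bigbird-pegasus
from transformers.onnx.features import FeaturesManager

features = FeaturesManager.get_supported_features_for_model_type("bigbird-pegasus")
print(sorted(features))  # default, causal-lm, seq2seq-lm, ... plus the -with-past variants

# get_config returns the config constructor bound to the given task
config_ctor = FeaturesManager.get_config("bigbird-pegasus", "seq2seq-lm")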
...
tests/onnx/test_onnx_v2.py
View file @
215e0681
...
@@ -201,6 +201,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
...
@@ -201,6 +201,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
(
"m2m-100"
,
"facebook/m2m100_418M"
),
(
"m2m-100"
,
"facebook/m2m100_418M"
),
(
"blenderbot-small"
,
"facebook/blenderbot_small-90M"
),
(
"blenderbot-small"
,
"facebook/blenderbot_small-90M"
),
(
"blenderbot"
,
"facebook/blenderbot-400M-distill"
),
(
"blenderbot"
,
"facebook/blenderbot-400M-distill"
),
(
"bigbird-pegasus"
,
"google/bigbird-pegasus-large-arxiv"
),
}
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
...
...
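Adding this tuple enrolls the checkpoint in the parameterized slow export tests, which export each (model, feature) pair and compare ONNX Runtime outputs against the PyTorch model. A condensed sketch of the numeric check those tests perform, assuming `validate_model_outputs` keeps its 4.19-era positional signature and reusing the names from the export sketch above:

# Sketch: numeric validation mirroring the slow test
# (onnx_config, tokenizer, model, onnx_outputs, and model.onnx come from the export sketch above)
from pathlib import Path

from transformers.onnx import validate_model_outputs

validate_model_outputs(
    onnx_config, tokenizer, model, Path("model.onnx"), onnx_outputs, onnx_config.atol_for_validation
)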