Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aaee4038
Unverified
Commit
aaee4038
authored
Apr 26, 2022
by
Krishna Sirumalla
Committed by
GitHub
Apr 26, 2022
Browse files
Add onnx config for RoFormer (#16861)
* add roformer onnx config
parent
8afaaa26
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
2 deletions
+37
-2
docs/source/en/serialization.mdx
docs/source/en/serialization.mdx
+1
-0
src/transformers/models/roformer/__init__.py
src/transformers/models/roformer/__init__.py
+2
-2
src/transformers/models/roformer/configuration_roformer.py
src/transformers/models/roformer/configuration_roformer.py
+21
-0
src/transformers/onnx/features.py
src/transformers/onnx/features.py
+12
-0
tests/onnx/test_onnx_v2.py
tests/onnx/test_onnx_v2.py
+1
-0
No files found.
docs/source/en/serialization.mdx
View file @
aaee4038
...
...
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
- OpenAI GPT-2
- PLBart
- RoBERTa
- RoFormer
- T5
- TAPEX
- ViT
...
...
src/transformers/models/roformer/__init__.py
View file @
aaee4038
...
...
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
_import_structure
=
{
"configuration_roformer"
:
[
"ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"
,
"RoFormerConfig"
],
"configuration_roformer"
:
[
"ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"
,
"RoFormerConfig"
,
"RoFormerOnnxConfig"
],
"tokenization_roformer"
:
[
"RoFormerTokenizer"
],
}
...
...
@@ -73,7 +73,7 @@ if is_flax_available():
if
TYPE_CHECKING
:
from
.configuration_roformer
import
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RoFormerConfig
from
.configuration_roformer
import
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RoFormerConfig
,
RoFormerOnnxConfig
from
.tokenization_roformer
import
RoFormerTokenizer
if
is_tokenizers_available
():
...
...
src/transformers/models/roformer/configuration_roformer.py
View file @
aaee4038
...
...
@@ -14,7 +14,11 @@
# limitations under the License.
""" RoFormer model configuration"""
from
collections
import
OrderedDict
from
typing
import
Mapping
from
...configuration_utils
import
PretrainedConfig
from
...onnx
import
OnnxConfig
from
...utils
import
logging
...
...
@@ -131,3 +135,20 @@ class RoFormerConfig(PretrainedConfig):
self
.
layer_norm_eps
=
layer_norm_eps
self
.
rotary_value
=
rotary_value
self
.
use_cache
=
use_cache
class
RoFormerOnnxConfig
(
OnnxConfig
):
@
property
def
inputs
(
self
)
->
Mapping
[
str
,
Mapping
[
int
,
str
]]:
if
self
.
task
==
"multiple-choice"
:
dynamic_axis
=
{
0
:
"batch"
,
1
:
"choice"
,
2
:
"sequence"
}
else
:
dynamic_axis
=
{
0
:
"batch"
,
1
:
"sequence"
}
dynamic_axis
=
{
0
:
"batch"
,
1
:
"sequence"
}
return
OrderedDict
(
[
(
"input_ids"
,
dynamic_axis
),
(
"attention_mask"
,
dynamic_axis
),
(
"token_type_ids"
,
dynamic_axis
),
]
)
src/transformers/onnx/features.py
View file @
aaee4038
...
...
@@ -25,6 +25,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
from
..models.marian
import
MarianOnnxConfig
from
..models.mbart
import
MBartOnnxConfig
from
..models.roberta
import
RobertaOnnxConfig
from
..models.roformer
import
RoFormerOnnxConfig
from
..models.t5
import
T5OnnxConfig
from
..models.vit
import
ViTOnnxConfig
from
..models.xlm_roberta
import
XLMRobertaOnnxConfig
...
...
@@ -333,6 +334,17 @@ class FeaturesManager:
"question-answering"
,
onnx_config_cls
=
Data2VecTextOnnxConfig
,
),
"roformer"
:
supported_features_mapping
(
"default"
,
"masked-lm"
,
"causal-lm"
,
"sequence-classification"
,
"token-classification"
,
"multiple-choice"
,
"question-answering"
,
"token-classification"
,
onnx_config_cls
=
RoFormerOnnxConfig
,
),
}
AVAILABLE_FEATURES
=
sorted
(
reduce
(
lambda
s1
,
s2
:
s1
|
s2
,
(
v
.
keys
()
for
v
in
_SUPPORTED_MODEL_TYPE
.
values
())))
...
...
tests/onnx/test_onnx_v2.py
View file @
aaee4038
...
...
@@ -179,6 +179,7 @@ PYTORCH_EXPORT_MODELS = {
(
"distilbert"
,
"distilbert-base-cased"
),
(
"electra"
,
"google/electra-base-generator"
),
(
"roberta"
,
"roberta-base"
),
(
"roformer"
,
"junnyu/roformer_chinese_base"
),
(
"xlm-roberta"
,
"xlm-roberta-base"
),
(
"layoutlm"
,
"microsoft/layoutlm-base-uncased"
),
(
"vit"
,
"google/vit-base-patch16-224"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment