Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
41c186d2
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5db9abde439bc02c3791da2a4fefee80d94d5b96"
Unverified
Commit
41c186d2
authored
Sep 23, 2021
by
Lysandre Debut
Committed by
GitHub
Sep 23, 2021
Browse files
Replace torch.set_grad_enabled by torch.no_grad (#13703)
parent
f888e5c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
38 deletions
+37
-38
src/transformers/onnx/convert.py
src/transformers/onnx/convert.py
+37
-38
No files found.
src/transformers/onnx/convert.py
View file @
41c186d2
...
@@ -90,44 +90,43 @@ def export(
...
@@ -90,44 +90,43 @@ def export(
raise
AssertionError
(
f
"Unsupported PyTorch version, minimum required is 1.8.0, got:
{
torch_version
}
"
)
raise
AssertionError
(
f
"Unsupported PyTorch version, minimum required is 1.8.0, got:
{
torch_version
}
"
)
logger
.
info
(
f
"Using framework PyTorch:
{
torch
.
__version__
}
"
)
logger
.
info
(
f
"Using framework PyTorch:
{
torch
.
__version__
}
"
)
torch
.
set_grad_enabled
(
False
)
with
torch
.
no_grad
():
model
.
config
.
return_dict
=
True
model
.
config
.
return_dict
=
True
model
.
eval
()
model
.
eval
()
# Check if we need to override certain configuration item
# Check if we need to override certain configuration item
if
config
.
values_override
is
not
None
:
if
config
.
values_override
is
not
None
:
logger
.
info
(
f
"Overriding
{
len
(
config
.
values_override
)
}
configuration item(s)"
)
logger
.
info
(
f
"Overriding
{
len
(
config
.
values_override
)
}
configuration item(s)"
)
for
override_config_key
,
override_config_value
in
config
.
values_override
.
items
():
for
override_config_key
,
override_config_value
in
config
.
values_override
.
items
():
logger
.
info
(
f
"
\t
-
{
override_config_key
}
->
{
override_config_value
}
"
)
logger
.
info
(
f
"
\t
-
{
override_config_key
}
->
{
override_config_value
}
"
)
setattr
(
model
.
config
,
override_config_key
,
override_config_value
)
setattr
(
model
.
config
,
override_config_key
,
override_config_value
)
# Ensure inputs match
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs
=
config
.
generate_dummy_inputs
(
tokenizer
,
framework
=
TensorType
.
PYTORCH
)
model_inputs
=
config
.
generate_dummy_inputs
(
tokenizer
,
framework
=
TensorType
.
PYTORCH
)
inputs_match
,
matched_inputs
=
ensure_model_and_config_inputs_match
(
model
,
model_inputs
.
keys
())
inputs_match
,
matched_inputs
=
ensure_model_and_config_inputs_match
(
model
,
model_inputs
.
keys
())
onnx_outputs
=
list
(
config
.
outputs
.
keys
())
onnx_outputs
=
list
(
config
.
outputs
.
keys
())
if
not
inputs_match
:
if
not
inputs_match
:
raise
ValueError
(
"Model and config inputs doesn't match"
)
raise
ValueError
(
"Model and config inputs doesn't match"
)
config
.
patch_ops
()
config
.
patch_ops
()
# export can works with named args but the dict containing named args as to be last element of the args tuple
# export can works with named args but the dict containing named args as to be last element of the args tuple
export
(
export
(
model
,
model
,
(
model_inputs
,),
(
model_inputs
,),
f
=
output
.
as_posix
(),
f
=
output
.
as_posix
(),
input_names
=
list
(
config
.
inputs
.
keys
()),
input_names
=
list
(
config
.
inputs
.
keys
()),
output_names
=
onnx_outputs
,
output_names
=
onnx_outputs
,
dynamic_axes
=
{
name
:
axes
for
name
,
axes
in
chain
(
config
.
inputs
.
items
(),
config
.
outputs
.
items
())},
dynamic_axes
=
{
name
:
axes
for
name
,
axes
in
chain
(
config
.
inputs
.
items
(),
config
.
outputs
.
items
())},
do_constant_folding
=
True
,
do_constant_folding
=
True
,
use_external_data_format
=
config
.
use_external_data_format
(
model
.
num_parameters
()),
use_external_data_format
=
config
.
use_external_data_format
(
model
.
num_parameters
()),
enable_onnx_checker
=
True
,
enable_onnx_checker
=
True
,
opset_version
=
opset
,
opset_version
=
opset
,
)
)
config
.
restore_ops
()
config
.
restore_ops
()
torch
.
set_grad_enabled
(
True
)
return
matched_inputs
,
onnx_outputs
return
matched_inputs
,
onnx_outputs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment