chenpangpang / transformers · Commits
Unverified commit 9080607b
Authored Nov 03, 2022 by Michael Benayoun, committed by GitHub on Nov 03, 2022
Fixed torch.finfo issue with torch.fx (#20040)
parent 6f257bb3
Changes: 2 changed files with 31 additions and 28 deletions (+31 -28)

src/transformers/utils/fx.py: +28 -8
tests/test_modeling_common.py: +3 -20
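The change below lets torch.fx symbolic tracing handle models that call torch.finfo (and torch.clamp). A minimal sketch of the pattern that previously broke tracing; the function name and mask convention here are illustrative, not taken from the commit:

import torch

def apply_attention_mask(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # torch.finfo(dtype).min is the most negative representable value for the
    # dtype. During symbolic tracing, scores.dtype used to be a proxy object,
    # which torch.finfo cannot accept, so models using this idiom failed to trace.
    return scores.masked_fill(mask, torch.finfo(scores.dtype).min)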
src/transformers/utils/fx.py
@@ -230,6 +230,15 @@ def torch_arange(*args, **kwargs):
     return torch.empty((end - start) // step, dtype=dtype, device="meta")
 
 
+def torch_full(*args, **kwargs):
+    args = list(args)
+    if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
+        args[1] = 1  # Any value.
+    kwargs_without_device = dict(kwargs)
+    kwargs_without_device.pop("device", None)
+    return torch.full(*args, **kwargs_without_device)
+
+
 def torch_cat(tensors, dim=None, axis=None, *, out=None):
     if dim is None and axis is None:
         dim = 0
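A quick usage sketch of the new torch_full override (illustrative, assuming the function above is in scope): a fill value living on the meta device has no concrete data, so the override swaps in a dummy scalar and drops the device kwarg; only the shape and dtype of the result matter for meta propagation.

import torch

fill = torch.empty((), device="meta")  # scalar fill value with no concrete data
out = torch_full((2, 3), fill, dtype=torch.float32, device="meta")
print(out.shape, out.dtype)  # torch.Size([2, 3]) torch.float32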
@@ -509,6 +518,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.where: torch_where,
     torch.abs: torch_abs,
     torch.arange: torch_arange,
+    torch.full: torch_full,
     torch.cat: torch_cat,
     torch.stack: torch_stack,
     torch.add: torch_add,
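Registering torch.full: torch_full means meta-tensor propagation consults the override instead of the real op. A simplified sketch of that dispatch, assuming _MANUAL_META_OVERRIDES and torch_full from this file are in scope (this is not the tracer's actual code path):

import torch

def dispatch_meta(fn, *args, **kwargs):
    # Use the manual override when one is registered, else call fn directly.
    override = _MANUAL_META_OVERRIDES.get(fn, fn)
    return override(*args, **kwargs)

meta_fill = torch.empty((), device="meta")
out = dispatch_meta(torch.full, (4,), meta_fill)  # routed to torch_full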
@@ -552,12 +562,6 @@ class HFProxy(Proxy):
     def shape(self):
         return self.tracer.create_proxy("call_method", "size", (self,), {})
 
-    @property
-    def dtype(self):
-        if hasattr(self, "_metadata") and self._metadata is not None:
-            return self._metadata.dtype
-        return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
-
     @property
     def device(self):
         # Hack so we can track when devices are used. During meta-tensor propagation,
@@ -597,12 +601,15 @@ class HFAttribute(HFProxy):
         self.tracer = root.tracer
         self._node = None
 
+        if hasattr(self.root, "_metadata"):
+            self.install_metadata(getattr(self.root._metadata, attr))
+
     @property
     def node(self):
         # the node for attributes is added lazily, since most will just be method calls
         # which do not rely on the getitem call
         if self._node is None:
-            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
+            self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
         return self._node
 
     def __call__(self, *args, **kwargs):
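With _metadata installed eagerly in the constructor, an attribute proxy such as tensor.dtype carries a concrete value, which is what allows torch.finfo(tensor.dtype) to resolve during tracing. A standalone sketch of the same idea (an illustration, not HFAttribute itself):

import torch

class AttrWithMetadata:
    # Keep the concrete attribute value alongside the symbolic access.
    def __init__(self, root_metadata: torch.Tensor, attr: str):
        self._metadata = getattr(root_metadata, attr)

meta = torch.empty(2, 2, dtype=torch.float16, device="meta")
dtype_attr = AttrWithMetadata(meta, "dtype")
print(torch.finfo(dtype_attr._metadata).min)  # concrete: -65504.0 for float16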
@@ -663,7 +670,18 @@ class HFTracer(Tracer):
 
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
     allow_insert_stateless_mods: bool = True
-    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
+    _TORCH_METHODS_TO_PATCH = [
+        "arange",
+        "zeros",
+        "ones",
+        "full",
+        "full_like",
+        "eye",
+        "empty",
+        "tensor",
+        "clamp",
+        "finfo",
+    ]
 
     def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
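clamp and finfo join the set of torch functions that HFTracer temporarily monkey-patches while tracing. A simplified sketch of that patch-and-restore pattern, with a hypothetical wrap_for_tracing helper standing in for the tracer's real wrapper logic:

import torch
from transformers.utils.fx import HFTracer

def wrap_for_tracing(fn):
    # Hypothetical wrapper: a real one would record the call as a graph node.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

originals = {name: getattr(torch, name) for name in HFTracer._TORCH_METHODS_TO_PATCH}
for name, fn in originals.items():
    setattr(torch, name, wrap_for_tracing(fn))
try:
    ...  # trace the model while the patches are active
finally:
    for name, fn in originals.items():
        setattr(torch, name, fn)  # always restore the originals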
@@ -737,6 +755,8 @@ class HFTracer(Tracer):
                 "GPT2DoubleHeadsModel",
             ]:
                 inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
+            elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
+                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
             else:
                 raise NotImplementedError(
                     f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
tests/test_modeling_common.py
@@ -835,17 +835,14 @@ class ModelTesterMixin:
                 filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                 input_names = list(filtered_inputs.keys())
 
-                model_output = model(**filtered_inputs)
-
-                if (
-                    isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
-                    and not hasattr(model.config, "problem_type")
-                    or model.config.problem_type is None
+                if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and (
+                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
                 ):
                     model.config.problem_type = "single_label_classification"
 
                 traced_model = symbolic_trace(model, input_names)
                 traced_output = traced_model(**filtered_inputs)
+                model_output = model(**filtered_inputs)
 
             except Exception as e:
                 self.fail(f"Couldn't trace module: {e}")
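For reference, the symbolic_trace helper exercised by this test can be called directly; a minimal sketch using a small, randomly initialized BERT (the config values here are chosen for speed and are not from the test):

from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace

model = BertModel(BertConfig(num_hidden_layers=2))
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
# traced is a torch.fx GraphModule and is called like the original model.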
@@ -871,20 +868,6 @@ class ModelTesterMixin:
                     f"traced {i}th output doesn't match model {i}th output for {model_class}",
                 )
 
-            # Test that the model can be TorchScripted
-            try:
-                scripted = torch.jit.script(traced_model)
-            except Exception as e:
-                self.fail(f"Could not TorchScript the traced model: {e}")
-
-            scripted_output = scripted(**filtered_inputs)
-            scripted_output = flatten_output(scripted_output)
-
-            for i in range(num_outputs):
-                self.assertTrue(
-                    torch.allclose(model_output[i], scripted_output[i]),
-                    f"scripted {i}th output doesn't match model {i}th output for {model_class}",
-                )
 
             # Test that the model can be serialized and restored properly
             with tempfile.TemporaryDirectory() as tmp_dir_name:
                 pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")