Unverified commit fbabd674, authored May 02, 2024 by Michael Benayoun, committed by GitHub on May 02, 2024

Fix for Neuron (#30259)

parent 5cf3e6bf
Showing 7 changed files with 240 additions and 99 deletions (+240 -99).
src/transformers/models/cohere/modeling_cohere.py  +5 -2
src/transformers/models/gemma/modeling_gemma.py    +5 -2
src/transformers/models/llama/modeling_llama.py    +5 -2
src/transformers/models/olmo/modeling_olmo.py      +5 -2
src/transformers/training_args.py                  +2 -2
src/transformers/utils/fx.py                       +218 -68
tests/test_modeling_common.py                      +0 -21
src/transformers/models/cohere/modeling_cohere.py

@@ -1010,8 +1010,11 @@ class CohereModel(CoherePreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
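All four modeling files get the same rewrite of the 2D-mask branch: the padding mask is now built with plain tensor addition and a single `== 0` comparison instead of multiplying two boolean `eq(0.0)` tensors, a formulation that is presumably easier for the Neuron tracer to handle. A minimal standalone check (toy shapes, not taken from the commit) that the two formulations flag the same positions:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# causal_mask: 0.0 where attention is allowed, min_dtype where it is blocked.
seq_len = 4
causal_mask = torch.full((1, 1, seq_len, seq_len), min_dtype, dtype=dtype).triu(diagonal=1)

# 2D attention mask: 1.0 for real tokens, 0.0 for padding.
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]], dtype=dtype)
mask_length = attention_mask.shape[-1]

# Old formulation: product of two boolean tensors.
old = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)

# New formulation: arithmetic followed by a single comparison.
new = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
new = new == 0

assert torch.equal(old, new)  # both flag positions that are causally visible but padded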
src/transformers/models/gemma/modeling_gemma.py

@@ -1001,8 +1001,11 @@ class GemmaModel(GemmaPreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
src/transformers/models/llama/modeling_llama.py

@@ -1089,8 +1089,11 @@ class LlamaModel(LlamaPreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
src/transformers/models/olmo/modeling_olmo.py

@@ -1068,8 +1068,11 @@ class OlmoModel(OlmoPreTrainedModel):
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
src/transformers/training_args.py

@@ -84,12 +84,12 @@ if is_torch_neuroncore_available(check_device=False):
     if os.environ.get("TORCHELASTIC_RUN_ID"):
         if is_optimum_neuron_available():
             logger.info(
-                "Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
+                "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
                 "will fail otherwise."
             )
         else:
             logger.warning(
-                "Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
+                "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
                 "training on AWS Trainium instances. More information here: "
                 "https://github.com/huggingface/optimum-neuron"
             )
src/transformers/utils/fx.py

@@ -15,22 +15,28 @@
 import builtins
 import collections
+import contextlib
 import functools
 import inspect
 import math
 import operator
 import os
 import random
+import sys
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

 import torch
+import torch.utils._pytree as pytree
 from torch import nn
-from torch.fx import Graph, GraphModule, Proxy, Tracer
+from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
 from torch.fx._compatibility import compatibility
+from torch.fx._symbolic_trace import is_fx_tracing
 from torch.fx.proxy import ParameterProxy

-from .. import PretrainedConfig, PreTrainedModel, logging
+from .. import logging
+from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
+from ..modeling_utils import PretrainedConfig, PreTrainedModel
 from ..models.auto import get_values
 from ..models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -55,7 +61,7 @@ from ..models.auto.modeling_auto import (
     MODEL_MAPPING_NAMES,
 )
 from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
-from ..utils import (
+from .import_utils import (
     ENV_VARS_TRUE_VALUES,
     TORCH_FX_REQUIRED_VERSION,
     get_torch_version,
@@ -192,6 +198,8 @@ _SPECIAL_SUPPORTED_MODELS = [
 ]
 _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))

+_CURRENT_TRACER = None
+

 def torch_nn_embedding(self, input):
     return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
@@ -701,6 +709,92 @@ class MetaDeviceAttribute(HFAttribute):
     pass


+class HFCacheProxy(HFProxy):
+    """
+    Proxy that represents an instance of `transformers.cache_utils.Cache`.
+    """
+
+    @property
+    def __class__(self):
+        return ProxyableCache
+
+
+def create_wrapper(
+    function: Callable,
+    op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
+    proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
+) -> Callable:
+    @functools.wraps(function)
+    def wrapper(*args, **kwargs):
+        if not is_fx_tracing():
+            return function(*args, **kwargs)
+
+        found_proxies = []
+
+        def check_proxy(a):
+            if isinstance(a, Proxy):
+                found_proxies.append(a)
+
+        torch.fx.node.map_aggregate(args, check_proxy)
+        torch.fx.node.map_aggregate(kwargs, check_proxy)
+
+        if len(found_proxies) > 0:
+            tracer = found_proxies[0].tracer
+            if op_type == "call_function":
+                target = function
+            elif op_type == "call_method":
+                target = function.__name__
+            elif op_type == "get_attr":
+                target = function.__name__
+            else:
+                raise ValueError(f"op_type {op_type} not supported.")
+            return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
+        else:
+            return function(*args, **kwargs)
+
+    return wrapper
+
+
+class HFProxyableClassMeta(type):
+    """
+    Metaclass that creates a class with its main methods wrapped to be proxyable.
+    """
+
+    def __new__(
+        cls,
+        name: str,
+        bases: Tuple[Type, ...],
+        attrs: Dict[str, Any],
+        proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
+    ):
+        cls = super().__new__(cls, name, bases, attrs)
+        for attr_name in dir(cls):
+            attr = getattr(cls, attr_name, None)
+            if attr is None:
+                continue
+            if attr_name == "__init__":
+                op_type = "call_function"
+            elif attr_name.startswith("__"):
+                op_type = None
+            elif inspect.ismethod(attr):
+                op_type = "call_function"
+            elif inspect.isfunction(attr):
+                op_type = "call_method"
+            else:
+                op_type = None
+            if op_type is not None:
+                setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
+        return cls
+
+
+def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
+    """
+    Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
+    """
+    wrapper = create_wrapper(target, "call_function")
+    return wrapper, target
+
+
 def _proxies_to_metas(v):
     """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
     if isinstance(v, MetaDeviceAttribute):
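`HFProxyableClassMeta` builds "proxyable" variants of the cache classes by wrapping their methods at class-creation time with `create_wrapper`. The snippet below is a much-simplified, standalone illustration of that metaclass pattern; the `InterceptingMeta`, `ToyCache`, and print-based wrapper are invented for the example and only mimic the shape of the real code:

import functools
import inspect


def log_wrapper(function, op_type):
    # Stand-in for create_wrapper: here we only log instead of emitting FX proxies.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        print(f"intercepted {op_type}: {function.__name__}")
        return function(*args, **kwargs)

    return wrapper


class InterceptingMeta(type):
    def __new__(cls, name, bases, attrs):
        new_cls = super().__new__(cls, name, bases, attrs)
        for attr_name in dir(new_cls):
            attr = getattr(new_cls, attr_name, None)
            if attr is None or attr_name.startswith("__"):
                continue
            if inspect.isfunction(attr):
                setattr(new_cls, attr_name, log_wrapper(attr, "call_method"))
        return new_cls


class ToyCache:
    def update(self, key, value):
        return (key, value)


# Same construction trick the commit uses for ProxyableCache & co.:
# build a wrapped subclass without touching the original class.
InterceptingCache = InterceptingMeta("InterceptingCache", (ToyCache,), {})
InterceptingCache().update("k", "v")  # prints: intercepted call_method: update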
@@ -712,25 +806,24 @@ def _proxies_to_metas(v):
     return v


-def _gen_constructor_wrapper(target):
-    @functools.wraps(target)
-    def wrapper(*args, **kwargs):
-        proxy = None
-
-        def check_has_proxy(v):
-            if isinstance(v, Proxy):
-                nonlocal proxy
-                proxy = v
-
-        torch.fx.node.map_aggregate(args, check_has_proxy)
-        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
-
-        if proxy is not None:
-            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
-        else:
-            return target(*args, **kwargs)
-
-    return wrapper, target
+def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
+    global _CURRENT_TRACER
+    if not isinstance(_CURRENT_TRACER, HFTracer):
+        raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
+    return HFCacheProxy(n, _CURRENT_TRACER)
+
+
+# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
+ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
+ProxyableDynamicCache = HFProxyableClassMeta(
+    "ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
+)
+ProxyableSinkCache = HFProxyableClassMeta(
+    "ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
+)
+ProxyableStaticCache = HFProxyableClassMeta(
+    "ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
+)


 def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -764,6 +857,13 @@ class HFTracer(Tracer):
         "finfo",
         "tril",
     ]
+    _CLASSES_TO_PATCH = {
+        Cache: ProxyableCache,
+        DynamicCache: ProxyableDynamicCache,
+        SinkCache: ProxyableSinkCache,
+        StaticCache: ProxyableStaticCache,
+    }
+
     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

     def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
@@ -776,7 +876,7 @@ class HFTracer(Tracer):
         )

     def _generate_dummy_input(
-        self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
+        self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
     ) -> Dict[str, torch.Tensor]:
         """Generates dummy input for model inference recording."""
         # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
@@ -951,6 +1051,11 @@ class HFTracer(Tracer):
             args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
             kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

+            should_install_metadata = True
+
+            self._disable_module_getattr = True
+            self._disable_call_module = True
+
             if kind == "call_function":
                 meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                 meta_out = meta_target(*args_metas, **kwargs_metas)
@@ -963,39 +1068,36 @@ class HFTracer(Tracer):
             elif kind == "call_module":
                 if not hasattr(self, "orig_forward"):
                     raise AttributeError(f"{self} does not have an attribute called orig_forward")
-                self._disable_module_getattr = True
-                try:
-                    mod = self.root.get_submodule(target)
-                    mod_type = type(mod)
-                    if mod_type in _MANUAL_META_OVERRIDES:
-                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
-                    else:
-                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
-                finally:
-                    self._disable_module_getattr = False
+                mod = self.root.get_submodule(target)
+                mod_type = type(mod)
+                if mod_type in _MANUAL_META_OVERRIDES:
+                    meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
+                else:
+                    meta_out = self.orig_forward(*args_metas, **kwargs_metas)
             elif kind == "get_attr":
-                self._disable_module_getattr = True
-                try:
-                    attr_itr = self.root
-                    atoms = target.split(".")
-                    for atom in atoms:
-                        attr_itr = getattr(attr_itr, atom)
-                    if isinstance(attr_itr, torch.Tensor):
-                        meta_out = attr_itr.to(device="meta")
-                    else:
-                        meta_out = attr_itr
-                finally:
-                    self._disable_module_getattr = False
+                attr_itr = self.root
+                atoms = target.split(".")
+                for atom in atoms:
+                    attr_itr = getattr(attr_itr, atom)
+                if isinstance(attr_itr, torch.Tensor):
+                    meta_out = attr_itr.to(device="meta")
+                else:
+                    meta_out = attr_itr
             else:
-                return rv
+                should_install_metadata = False

-            if not isinstance(rv, Proxy):
-                raise ValueError("Don't support composite output yet")
-            rv.install_metadata(meta_out)
+            if should_install_metadata:
+                if not isinstance(rv, Proxy):
+                    raise ValueError("Don't support composite output yet")
+                rv.install_metadata(meta_out)
         except Exception as e:
             if _IS_IN_DEBUG_MODE:
                 warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

+        self._disable_module_getattr = False
+        self._disable_call_module = False
+
         return rv

     # Replaced by .getattr from PyTorch 1.13
@@ -1041,12 +1143,51 @@ class HFTracer(Tracer):
         return self._module_getattr(attr, attr_val, parameter_proxy_cache)

     def call_module(self, m, forward, args, kwargs):
+        if getattr(self, "_disable_call_module", False):
+            return forward(*args, **kwargs)
         self.orig_forward = forward
         return super().call_module(m, forward, args, kwargs)

     def proxy(self, node):
         return HFProxy(node, self)

+    @contextlib.contextmanager
+    def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
+        # Patching torch functions
+        self.patched_torch_methods = {
+            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
+        }
+        self.orig_fns = set()
+
+        for name, (wrapper, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, wrapper)
+            self.orig_fns.add(orig)
+
+        # Patching classes
+        patched = []
+        module_of_model = inspect.getmodule(root)
+        for name, mod in sys.modules.items():
+            if module_of_model is not None and mod is not module_of_model:
+                continue
+            if not name.startswith("transformers"):
+                continue
+            for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
+                for attr_name, attr in mod.__dict__.items():
+                    if attr is orig_cls:
+                        patched.append((mod, attr_name, orig_cls))
+                        setattr(mod, attr_name, patched_cls)
+
+        yield
+
+        # Restoring patched functions and classes.
+        for name, (_, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, orig)
+        self.patched_torch_methods = {}
+        self.orig_fns = set()
+
+        for mod, attr_name, orig_cls in patched:
+            setattr(mod, attr_name, orig_cls)
+
     def trace(
         self,
         root: Union[torch.nn.Module, Callable[..., Any]],
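`patch_for_tracing` moves the patch/trace/restore dance out of `trace()` and into a context manager: torch tensor constructors and the cache classes are swapped for proxyable versions before the `yield`, and the originals are put back afterwards. A tiny standalone sketch of that shape (the `patched` helper and the `math.pi` target are invented for illustration):

import contextlib
import math


@contextlib.contextmanager
def patched(module, name, replacement):
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        # Always restore, even if tracing inside the block raises.
        setattr(module, name, original)


with patched(math, "pi", 3.0):
    print(math.pi)  # 3.0 while the patch is active
print(math.pi)      # original value restored afterwards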
@@ -1125,28 +1266,25 @@ class HFTracer(Tracer):
                 " transformers.PreTrainedModel."
             )

-        concrete_metas = {
-            input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
-            for input_name, input_ in inputs.items()
-        }
+        def to_meta(value):
+            if isinstance(value, torch.Tensor):
+                return value.to("meta")
+            return value
+
+        concrete_metas = pytree.tree_map(to_meta, inputs)
+
         for param in sig.parameters.values():
             if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                 concrete_metas[f"**{param.name}"] = {}
         self.meta_args = concrete_metas
-        self.patched_torch_methods = {
-            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
-        }
-        self.orig_fns = set()

-        for name, (wrapper, orig) in self.patched_torch_methods.items():
-            setattr(torch, name, wrapper)
-            self.orig_fns.add(orig)
-
-        try:
-            self.graph = super().trace(root, concrete_args=concrete_args)
-        finally:
-            for name, (_, orig) in self.patched_torch_methods.items():
-                setattr(torch, name, orig)
+        global _CURRENT_TRACER
+        _CURRENT_TRACER = self
+        with self.patch_for_tracing(root):
+            try:
+                self.graph = super().trace(root, concrete_args=concrete_args)
+            finally:
+                _CURRENT_TRACER = None

         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
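In `trace()`, the dict comprehension that moved inputs to the meta device is replaced with `pytree.tree_map(to_meta, inputs)`, which also reaches tensors nested inside lists and tuples such as `past_key_values`. A small standalone illustration with hypothetical inputs:

import torch
import torch.utils._pytree as pytree


def to_meta(value):
    if isinstance(value, torch.Tensor):
        return value.to("meta")
    return value


inputs = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "past_key_values": [(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4))],
}
meta_inputs = pytree.tree_map(to_meta, inputs)
print(meta_inputs["input_ids"].device)              # meta
print(meta_inputs["past_key_values"][0][0].device)  # meta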
@@ -1256,11 +1394,11 @@ def get_concrete_args(model: nn.Module, input_names: List[str]):
     return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}


-def is_model_supported(model: PreTrainedModel):
+def is_model_supported(model: "PreTrainedModel"):
     return model.__class__.__name__ in _SUPPORTED_MODELS


-def check_if_model_is_supported(model: PreTrainedModel):
+def check_if_model_is_supported(model: "PreTrainedModel"):
     if not is_model_supported(model):
         supported_model_names = ", ".join(_SUPPORTED_MODELS)
         raise NotImplementedError(
@@ -1269,7 +1407,7 @@ def check_if_model_is_supported(model: PreTrainedModel):


 def symbolic_trace(
-    model: PreTrainedModel,
+    model: "PreTrainedModel",
     input_names: Optional[List[str]] = None,
     disable_check: bool = False,
     tracer_cls: Type[HFTracer] = HFTracer,
@@ -1307,6 +1445,18 @@ def symbolic_trace(
...
@@ -1307,6 +1445,18 @@ def symbolic_trace(
if
not
disable_check
:
if
not
disable_check
:
check_if_model_is_supported
(
model
)
check_if_model_is_supported
(
model
)
if
"past_key_values"
in
input_names
and
not
getattr
(
model
.
config
,
"use_cache"
,
False
):
logger
.
warning
(
"`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
"unexpected behavior."
)
if
"past_key_values"
not
in
input_names
and
getattr
(
model
.
config
,
"use_cache"
,
False
):
logger
.
warning
(
"`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
"model.config.use_cache = False."
)
model
.
config
.
use_cache
=
False
# Tracing.
# Tracing.
tracer
=
tracer_cls
()
tracer
=
tracer_cls
()
traced_graph
=
tracer
.
trace
(
model
,
concrete_args
=
concrete_args
)
traced_graph
=
tracer
.
trace
(
model
,
concrete_args
=
concrete_args
)
...
...
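With the cache classes proxyable, `symbolic_trace` now also reconciles `past_key_values` in `input_names` with `model.config.use_cache`. A hedged usage sketch; the tiny checkpoint name is only an example (it needs network access), and any architecture in `_SUPPORTED_MODELS` should behave the same way:

from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

# Hypothetical tiny checkpoint, used here only to keep the example small.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])
print(type(traced))  # a torch.fx GraphModule wrapping the traced model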
tests/test_modeling_common.py

@@ -18,7 +18,6 @@ import gc
 import inspect
 import os
 import os.path
-import pickle
 import random
 import re
 import tempfile
@@ -1279,26 +1278,6 @@ class ModelTesterMixin:
                     f"traced {i}th output doesn't match model {i}th output for {model_class}",
                 )

-            # Test that the model can be serialized and restored properly
-            with tempfile.TemporaryDirectory() as tmp_dir_name:
-                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
-                try:
-                    with open(pkl_file_name, "wb") as f:
-                        pickle.dump(traced_model, f)
-                    with open(pkl_file_name, "rb") as f:
-                        loaded = pickle.load(f)
-                except Exception as e:
-                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
-
-                loaded_output = loaded(**filtered_inputs)
-                loaded_output = flatten_output(loaded_output)
-
-                for i in range(num_outputs):
-                    self.assertTrue(
-                        torch.allclose(model_output[i], loaded_output[i]),
-                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
-                    )
-
             # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
             # (Even with this call, there are still memory leak by ~0.04MB)
             self.clear_torch_jit_class_registry()