Commit 31484afb (unverified)
Authored May 25, 2022 by Sylvain Gugger; committed via GitHub on May 25, 2022

Add test for new model parallelism features (#17401)
Parent: 56b35ce3

Showing 3 changed files with 103 additions and 9 deletions:

- src/transformers/modeling_utils.py (+7 / -1)
- src/transformers/models/t5/modeling_t5.py (+6 / -8)
- tests/test_modeling_common.py (+90 / -0)
src/transformers/modeling_utils.py

```diff
@@ -1734,6 +1734,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 same device.
 
                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                GPU and the available CPU RAM if unset.
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*, defaults to `False`):
@@ -1822,6 +1825,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         torch_dtype = kwargs.pop("torch_dtype", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
         device_map = kwargs.pop("device_map", None)
+        max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", False)
@@ -2119,7 +2123,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
             no_split_modules = model._no_split_modules
-            device_map = infer_auto_device_map(model, no_split_module_classes=no_split_modules, dtype=torch_dtype)
+            device_map = infer_auto_device_map(
+                model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
+            )
 
         if from_tf:
             if resolved_archive_file.endswith(".index"):
```
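In short, `from_pretrained` now accepts a `max_memory` kwarg and forwards it to Accelerate's `infer_auto_device_map`. A minimal usage sketch, not taken from the commit itself; the checkpoint name and memory limits are illustrative and assume `accelerate` is installed with a CUDA GPU available:

```python
# Illustrative only: cap GPU 0 at 6GiB and let remaining weights live on CPU.
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "t5-large",                              # any model class that defines _no_split_modules
    device_map="auto",                       # Accelerate computes the placement
    max_memory={0: "6GiB", "cpu": "30GiB"},  # the new kwarg this commit plumbs through
)
# The resulting placement is recorded on the model, e.g. {"shared": 0, ..., "decoder": "cpu"}
print(model.hf_device_map)
```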
src/transformers/models/t5/modeling_t5.py

```diff
@@ -420,14 +420,12 @@ class T5Attention(nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length):
+    def compute_bias(self, query_length, key_length, device=None):
         """Compute binned relative position bias"""
-        context_position = torch.arange(
-            query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )[:, None]
-        memory_position = torch.arange(
-            key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )[None, :]
+        if device is None:
+            device = self.relative_attention_bias.weight.device
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
@@ -522,7 +520,7 @@ class T5Attention(nn.Module):
             if self.gradient_checkpointing and self.training:
                 position_bias.requires_grad = True
             else:
-                position_bias = self.compute_bias(real_seq_length, key_length)
+                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
 
             # if key and values are already calculated
             # we want only the last query position bias
```
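The motivation: with `device_map="auto"`, a T5 model's blocks can sit on several devices while `relative_attention_bias` lives on only one of them, so the bias must be built on `scores.device` to keep the later `scores + position_bias` addition on a single device. A simplified, standalone sketch of the pattern; this is not the library code, just the same device-selection idiom:

```python
# Sketch: build an index tensor on an explicitly chosen device so it can be
# combined with activations that may live on a different GPU.
import torch

def relative_positions(query_length: int, key_length: int, device=None) -> torch.Tensor:
    if device is None:
        device = torch.device("cpu")  # the real code falls back to the bias table's device
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    return memory_position - context_position  # shape (query_length, key_length)

scores = torch.zeros(4, 4)  # stand-in for attention scores on some device
bias_input = relative_positions(4, 4, device=scores.device)  # no cross-device mismatch
```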
tests/test_modeling_common.py

```diff
@@ -51,7 +51,9 @@ from transformers.testing_utils import (
     is_pt_flax_cross_test,
     is_pt_tf_cross_test,
     is_staging_test,
+    require_accelerate,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_gpu,
     require_usr_bin_time,
     slow,
```
```diff
@@ -60,6 +62,7 @@ from transformers.testing_utils import (
 from transformers.utils import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    is_accelerate_available,
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
```
```diff
@@ -72,6 +75,10 @@ sys.path.append(str(Path(__file__).parent.parent / "utils"))
 from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402
 
 
+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
+
 if is_torch_available():
     import torch
     from torch import nn
```
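`compute_module_sizes` is imported only when `accelerate` is available; the tests below read its root entry, keyed by the empty string, as the total model size in bytes. A small sketch of that behaviour, using a toy model of my own choosing:

```python
# Sketch: total parameter footprint via accelerate (assumes accelerate is installed).
import torch.nn as nn
from accelerate.utils import compute_module_sizes

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
sizes = compute_module_sizes(model)  # dict: dotted module name -> size in bytes
print(sizes[""])   # "" is the root module, i.e. the whole model
print(sizes["0"])  # size of the first Linear alone
```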
```diff
@@ -2178,6 +2185,86 @@ class ModelTesterMixin:
             model.parallelize()
             model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
 
+    def check_device_map_is_respected(self, model, device_map):
+        for param_name, param in model.named_parameters():
+            # Find device in device_map
+            while len(param_name) > 0 and param_name not in device_map:
+                param_name = ".".join(param_name.split(".")[:-1])
+            if param_name not in device_map:
+                raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
+
+            param_device = device_map[param_name]
+            if param_device in ["cpu", "disk"]:
+                self.assertEqual(param.device, torch.device("meta"))
+            else:
+                self.assertEqual(param.device, torch.device(param_device))
+
+    @require_accelerate
+    @require_torch_gpu
+    def test_cpu_offload(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.num_hidden_layers < 5:
+            config.num_hidden_layers = 5
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            base_output = model(**inputs_dict)
+
+            model_size = compute_module_sizes(model)[""]
+            # We test several splits of sizes to make sure it works.
+            max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                for max_size in max_gpu_sizes:
+                    max_memory = {0: max_size, "cpu": model_size * 2}
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                    # Making sure part of the model will actually end up offloaded
+                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+                    new_output = new_model(**inputs_dict)
+                    self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
+    @require_accelerate
+    @require_torch_multi_gpu
+    def test_model_parallelism(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.num_hidden_layers < 5:
+            config.num_hidden_layers = 5
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            base_output = model(**inputs_dict)
+
+            model_size = compute_module_sizes(model)[""]
+            # We test several splits of sizes to make sure it works.
+            max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                for max_size in max_gpu_sizes:
+                    max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                    # Making sure part of the model will actually end up offloaded
+                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+
+                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+                    new_output = new_model(**inputs_dict)
+                    self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
     def test_problem_types(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
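Note how `check_device_map_is_respected` resolves a parameter's device: a `device_map` usually keys whole submodules rather than individual parameters, so the dotted parameter name is truncated piece by piece until it matches a key. A standalone sketch of that walk-up, with a hypothetical map:

```python
# Sketch of the name walk-up the new helper performs: truncate the dotted
# parameter name until it hits an entry in the (submodule-keyed) device map.
device_map = {"encoder": 0, "decoder": "cpu"}  # hypothetical map

def device_for(param_name: str, device_map: dict):
    while len(param_name) > 0 and param_name not in device_map:
        param_name = ".".join(param_name.split(".")[:-1])
    if param_name not in device_map:
        raise ValueError(f"device map has no device for {param_name!r}")
    return device_map[param_name]

print(device_for("encoder.block.0.layer.0.weight", device_map))   # -> 0
print(device_for("decoder.final_layer_norm.weight", device_map))  # -> "cpu"
```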
```diff
@@ -2547,6 +2634,7 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(model.parameters(), ref_model.parameters()):
             self.assertTrue(torch.allclose(p1, p2))
 
+    @require_accelerate
     def test_from_pretrained_low_cpu_mem_usage_functional(self):
         # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
         # sharded models
```
```diff
@@ -2559,6 +2647,7 @@ class ModelUtilsTest(TestCasePlus):
             _ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
 
     @require_usr_bin_time
+    @require_accelerate
    def test_from_pretrained_low_cpu_mem_usage_measured(self):
         # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
```
```diff
@@ -2597,6 +2686,7 @@ class ModelUtilsTest(TestCasePlus):
     # functionality to load models directly on gpu, this test can be rewritten to use torch's
     # cuda memory tracking and then we should be able to do a much more precise test.
     @require_accelerate
+    @require_torch_multi_gpu
     @slow
     def test_model_parallelism_gpt2(self):
```
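Because the new tests live on `ModelTesterMixin`, every per-model test class inherits them automatically. One plausible way to run just these tests for a single model family from Python; the test path is illustrative, and the run assumes a machine with `accelerate` and the required GPUs:

```python
# Illustrative: run only the new offload/parallelism tests for one model family.
import pytest

pytest.main(["tests/models/t5/test_modeling_t5.py", "-k", "test_cpu_offload or test_model_parallelism"])
```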