chenpangpang / transformers / Commits / 31484afb

Unverified commit 31484afb, authored May 25, 2022 by Sylvain Gugger, committed by GitHub on May 25, 2022

Add test for new model parallelism features (#17401)

parent 56b35ce3
Changes: 3 changed files with 103 additions and 9 deletions (+103 -9)

  src/transformers/modeling_utils.py            +7  -1
  src/transformers/models/t5/modeling_t5.py     +6  -8
  tests/test_modeling_common.py                 +90 -0
src/transformers/modeling_utils.py
@@ -1734,6 +1734,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 same device.

                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                GPU and the available CPU RAM if unset.
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*, defaults to `False`):
@@ -1822,6 +1825,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         torch_dtype = kwargs.pop("torch_dtype", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
         device_map = kwargs.pop("device_map", None)
+        max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", False)
@@ -2119,7 +2123,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
             no_split_modules = model._no_split_modules
-            device_map = infer_auto_device_map(model, no_split_module_classes=no_split_modules, dtype=torch_dtype)
+            device_map = infer_auto_device_map(
+                model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
+            )

         if from_tf:
             if resolved_archive_file.endswith(".index"):
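In practice, the new `max_memory` argument flows from `from_pretrained` into Accelerate's `infer_auto_device_map`. A minimal usage sketch, assuming Accelerate is installed and at least one GPU is available; the checkpoint name and memory budgets below are illustrative, not part of this commit:

from transformers import AutoModel

# Cap GPU 0 at roughly 2 GiB (values given in bytes); whatever does not fit
# stays in CPU RAM, up to 8 GiB. Accelerate derives the device_map from this.
model = AutoModel.from_pretrained(
    "t5-small",
    device_map="auto",
    max_memory={0: 2 * 1024**3, "cpu": 8 * 1024**3},
)
print(model.hf_device_map)  # e.g. {"shared": 0, "encoder": 0, "decoder": "cpu", ...}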
src/transformers/models/t5/modeling_t5.py
@@ -420,14 +420,12 @@ class T5Attention(nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length):
+    def compute_bias(self, query_length, key_length, device=None):
         """Compute binned relative position bias"""
-        context_position = torch.arange(
-            query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )[:, None]
-        memory_position = torch.arange(
-            key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
-        )[None, :]
+        if device is None:
+            device = self.relative_attention_bias.weight.device
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
@@ -522,7 +520,7 @@ class T5Attention(nn.Module):
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
-                position_bias = self.compute_bias(real_seq_length, key_length)
+                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

             # if key and values are already calculated
             # we want only the last query position bias
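The reason for threading a `device` argument through `compute_bias`: with `device_map="auto"` the attention scores and the `relative_attention_bias` embedding can land on different GPUs, and index tensors must be created on the device where they are consumed. A minimal sketch of the situation this avoids (purely illustrative, requires two CUDA devices):

import torch

if torch.cuda.device_count() >= 2:
    scores = torch.zeros(4, 4, device="cuda:1")  # attention scores live on GPU 1
    # Previously the position indices were built on the embedding's device (say GPU 0),
    # which fails when combined with tensors on GPU 1:
    #   torch.arange(4, device="cuda:0") + scores  ->  device mismatch error
    # Building them where they are used, as compute_bias(..., device=scores.device) now does:
    idx = torch.arange(4, dtype=torch.long, device=scores.device)
    print((scores + idx).device)  # cuda:1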
tests/test_modeling_common.py
@@ -51,7 +51,9 @@ from transformers.testing_utils import (
     is_pt_flax_cross_test,
     is_pt_tf_cross_test,
     is_staging_test,
+    require_accelerate,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_gpu,
     require_usr_bin_time,
     slow,
@@ -60,6 +62,7 @@ from transformers.testing_utils import (
 from transformers.utils import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    is_accelerate_available,
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
@@ -72,6 +75,10 @@ sys.path.append(str(Path(__file__).parent.parent / "utils"))
 from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402


+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
+
 if is_torch_available():
     import torch
     from torch import nn
@@ -2178,6 +2185,86 @@ class ModelTesterMixin:
             model.parallelize()
             model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)

+    def check_device_map_is_respected(self, model, device_map):
+        for param_name, param in model.named_parameters():
+            # Find device in device_map
+            while len(param_name) > 0 and param_name not in device_map:
+                param_name = ".".join(param_name.split(".")[:-1])
+            if param_name not in device_map:
+                raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
+
+            param_device = device_map[param_name]
+            if param_device in ["cpu", "disk"]:
+                self.assertEqual(param.device, torch.device("meta"))
+            else:
+                self.assertEqual(param.device, torch.device(param_device))
+
+    @require_accelerate
+    @require_torch_gpu
+    def test_cpu_offload(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        if config.num_hidden_layers < 5:
+            config.num_hidden_layers = 5
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            base_output = model(**inputs_dict)
+
+            model_size = compute_module_sizes(model)[""]
+            # We test several splits of sizes to make sure it works.
+            max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                for max_size in max_gpu_sizes:
+                    max_memory = {0: max_size, "cpu": model_size * 2}
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                    # Making sure part of the model will actually end up offloaded
+                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                    new_output = new_model(**inputs_dict)
+
+                    self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
+    @require_accelerate
+    @require_torch_multi_gpu
+    def test_model_parallelism(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        if config.num_hidden_layers < 5:
+            config.num_hidden_layers = 5
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            base_output = model(**inputs_dict)
+
+            model_size = compute_module_sizes(model)[""]
+            # We test several splits of sizes to make sure it works.
+            max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                for max_size in max_gpu_sizes:
+                    max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                    # Making sure part of the model will actually end up offloaded
+                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+
+                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                    new_output = new_model(**inputs_dict)
+
+                    self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
     def test_problem_types(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
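The name resolution in `check_device_map_is_respected` above relies on device maps being keyed by module prefixes rather than full parameter names, so the test walks each parameter name upward until it hits an entry. A standalone sketch of that lookup, with a made-up device map and parameter name for illustration:

device_map = {"encoder.block.0": 0, "encoder.block.1": "cpu"}
param_name = "encoder.block.1.layer.0.SelfAttention.q.weight"

# Strip trailing components until the name matches a device_map key.
while len(param_name) > 0 and param_name not in device_map:
    param_name = ".".join(param_name.split(".")[:-1])

print(param_name)              # "encoder.block.1"
print(device_map[param_name])  # "cpu" -> the parameter is expected on the meta device (offloaded)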
@@ -2547,6 +2634,7 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(model.parameters(), ref_model.parameters()):
             self.assertTrue(torch.allclose(p1, p2))

+    @require_accelerate
     def test_from_pretrained_low_cpu_mem_usage_functional(self):
         # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
         # sharded models
@@ -2559,6 +2647,7 @@ class ModelUtilsTest(TestCasePlus):
         _ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)

     @require_usr_bin_time
+    @require_accelerate
     def test_from_pretrained_low_cpu_mem_usage_measured(self):
         # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
@@ -2597,6 +2686,7 @@ class ModelUtilsTest(TestCasePlus):
         # functionality to load models directly on gpu, this test can be rewritten to use torch's
         # cuda memory tracking and then we should be able to do a much more precise test.

+    @require_accelerate
     @require_torch_multi_gpu
     @slow
     def test_model_parallelism_gpt2(self):
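Both new tests size the model with Accelerate's `compute_module_sizes` and then try GPU budgets of 50%, 70%, and 90% of that size, so at least part of the model is forced off the first device. A rough sketch of how that sizing step works (the toy module is illustrative):

import torch
from torch import nn
from accelerate.utils import compute_module_sizes

# compute_module_sizes returns a dict of byte sizes per module prefix;
# the empty-string key is the total for the whole model.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in [0.5, 0.7, 0.9]]
print(model_size, max_gpu_sizes)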