renzhc / diffusers_dcu · Commits · a1f9a712

Unverified commit a1f9a712, authored Jan 21, 2025 by YiYi Xu, committed by GitHub on Jan 21, 2025

fix offload gpu tests etc (#10366)

* add
* style

parent ec37e209
Showing 3 changed files with 24 additions and 39 deletions (+24 −39):

src/diffusers/models/transformers/sana_transformer.py (+17 −9)
tests/models/test_modeling_common.py (+6 −5)
tests/models/transformers/test_models_transformer_sana.py (+1 −25)
src/diffusers/models/transformers/sana_transformer.py
@@ -82,6 +82,20 @@ class GLUMBConv(nn.Module):
         return hidden_states


+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
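For context, the new module takes hidden_states of shape (batch, tokens, dim), a per-sample timestep embedding temb of shape (batch, dim), and the model-owned (2, dim) scale_shift_table. A minimal usage sketch; the shapes are assumptions read off the broadcasting in forward() above, not stated in the diff:

import torch

dim = 32
norm = SanaModulatedNorm(dim)                       # class added in this hunk
scale_shift_table = torch.randn(2, dim) / dim**0.5  # owned by the parent model
hidden_states = torch.randn(1, 16, dim)             # (batch, tokens, dim)
temb = torch.randn(1, dim)                          # (batch, dim) timestep embedding

out = norm(hidden_states, temb, scale_shift_table)  # (1, 16, dim)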
@@ -221,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """

     _supports_gradient_checkpointing = True
-    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]

     @register_to_config
     def __init__(
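Listing SanaModulatedNorm in _no_split_modules tells accelerate's device-map planner never to shard that module across devices, which is what the failing offload tests were tripping over. A hedged sketch of the loading path this protects; the checkpoint name is hypothetical:

from diffusers import SanaTransformer2DModel

# Hypothetical checkpoint location; the point is device_map="auto", which
# consults _no_split_modules when distributing layers across GPUs and CPU.
model = SanaTransformer2DModel.from_pretrained(
    "some-org/some-sana-checkpoint", subfolder="transformer", device_map="auto"
)
print(model.hf_device_map)  # module-name -> device placement chosen by accelerate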
@@ -288,8 +302,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)

-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

         self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )

         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
+
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
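The refactor folds the shift/scale modulation into norm_out so the whole step lives inside one unsplittable module; numerically nothing changes. A quick equivalence check under the same assumptions, using SanaModulatedNorm as defined in the first hunk:

import torch
import torch.nn as nn

dim = 8
h = torch.randn(2, 4, dim)                 # hidden states
temb = torch.randn(2, dim)                 # embedded timestep
table = torch.randn(2, dim) / dim**0.5     # scale_shift_table

# Old path: normalize, then modulate inline.
shift, scale = (table[None] + temb[:, None]).chunk(2, dim=1)
old = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)(h) * (1 + scale) + shift

# New path: a single call into the module added above.
new = SanaModulatedNorm(dim)(h, temb, table)

assert torch.allclose(old, new)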
tests/models/test_modeling_common.py
@@ -29,7 +29,7 @@ import numpy as np
 import requests_mock
 import torch
 import torch.nn as nn
-from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
+from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
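compute_module_sizes is the long-standing accelerate helper that sums parameter and buffer bytes per module; the empty-string key holds the total for the whole model. A minimal sketch:

import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
sizes = compute_module_sizes(model)
print(sizes[""])   # total size in bytes for the whole model
print(sizes["0"])  # size of the first submodule (the Linear layer)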
@@ -1080,7 +1080,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
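The [1:] slice skips the first split percent, so with Sana's new model_split_percents = [0.7, 0.7, 0.9] (set in the test file below) the test caps the GPU at 70% and then 90% of the model's size. As a worked example with an illustrative size:

model_split_percents = [0.7, 0.7, 0.9]  # value set for the Sana tests
model_size = 10_000                     # bytes, illustrative only
max_gpu_sizes = [int(p * model_size) for p in model_split_percents[1:]]
print(max_gpu_sizes)  # [7000, 9000]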
@@ -1110,7 +1110,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
@@ -1144,7 +1144,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
@@ -1172,7 +1172,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1183,6 +1183,7 @@ class ModelTesterMixin:
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            print(f" new_model.hf_device_map: {new_model.hf_device_map}")

             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
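The added print surfaces the resolved placement in test logs when the {0, 1} assertion fires. hf_device_map maps module names to devices; for a two-GPU split it might look like the following (illustrative entries, not taken from a real run):

# Illustrative only; real keys follow the model's module tree.
hf_device_map = {
    "patch_embed": 0,
    "transformer_blocks.0": 0,
    "transformer_blocks.1": 1,
    "norm_out": 1,
    "proj_out": 1,
}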
tests/models/transformers/test_models_transformer_sana.py
@@ -14,7 +14,6 @@
import
unittest
import
pytest
import
torch
from
diffusers
import
SanaTransformer2DModel
...
...
@@ -33,6 +32,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
    uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.9]

     @property
     def dummy_input(self):
@@ -81,27 +81,3 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SanaTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_cpu_offload(self):
-        return super().test_cpu_offload()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_with_safetensors(self):
-        return super().test_disk_offload_with_safetensors()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_without_safetensors(self):
-        return super().test_disk_offload_without_safetensors()
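Deleting these markers outright, rather than leaving them as documentation, is forced by strict=True: once the _no_split_modules fix makes the offload tests pass, a strict xfail turns the unexpected pass into a suite failure. A minimal illustration of that semantics:

import pytest

@pytest.mark.xfail(reason="bug since fixed", strict=True)
def test_now_passes():
    assert True  # XPASS: with strict=True, pytest reports this as FAILED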