Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
5df02fc1
Unverified
Commit
5df02fc1
authored
Jun 24, 2025
by
Aryan
Committed by
GitHub
Jun 24, 2025
Browse files
[tests] Fix group offloading and layerwise casting test interaction (#11796)
* update * update * update
parent
7392c8ff
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
16 deletions
+17
-16
src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
+10
-8
tests/models/test_modeling_common.py
tests/models/test_modeling_common.py
+7
-8
No files found.
src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
View file @
5df02fc1
...
@@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module):
...
@@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module):
self
.
patch_size
=
patch_size
self
.
patch_size
=
patch_size
self
.
patch_method
=
patch_method
self
.
patch_method
=
patch_method
self
.
register_buffer
(
"wavelets"
,
_WAVELETS
[
patch_method
],
persistent
=
False
)
wavelets
=
_WAVELETS
.
get
(
patch_method
).
clone
()
self
.
register_buffer
(
"_arange"
,
torch
.
arange
(
_WAVELETS
[
patch_method
].
shape
[
0
]),
persistent
=
False
)
arange
=
torch
.
arange
(
wavelets
.
shape
[
0
])
self
.
register_buffer
(
"wavelets"
,
wavelets
,
persistent
=
False
)
self
.
register_buffer
(
"_arange"
,
arange
,
persistent
=
False
)
def
_dwt
(
self
,
hidden_states
:
torch
.
Tensor
,
mode
:
str
=
"reflect"
,
rescale
=
False
)
->
torch
.
Tensor
:
def
_dwt
(
self
,
hidden_states
:
torch
.
Tensor
,
mode
:
str
=
"reflect"
,
rescale
=
False
)
->
torch
.
Tensor
:
dtype
=
hidden_states
.
dtype
dtype
=
hidden_states
.
dtype
...
@@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module):
...
@@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module):
self
.
patch_size
=
patch_size
self
.
patch_size
=
patch_size
self
.
patch_method
=
patch_method
self
.
patch_method
=
patch_method
self
.
register_buffer
(
"wavelets"
,
_WAVELETS
[
patch_method
],
persistent
=
False
)
wavelets
=
_WAVELETS
.
get
(
patch_method
).
clone
()
self
.
register_buffer
(
arange
=
torch
.
arange
(
wavelets
.
shape
[
0
])
"_arange"
,
torch
.
arange
(
_WAVELETS
[
patch_method
].
shape
[
0
]),
self
.
register_buffer
(
"wavelets"
,
wavelets
,
persistent
=
False
)
persistent
=
False
,
self
.
register_buffer
(
"_arange"
,
arange
,
persistent
=
False
)
)
def
_idwt
(
self
,
hidden_states
:
torch
.
Tensor
,
rescale
:
bool
=
False
)
->
torch
.
Tensor
:
def
_idwt
(
self
,
hidden_states
:
torch
.
Tensor
,
rescale
:
bool
=
False
)
->
torch
.
Tensor
:
device
=
hidden_states
.
device
device
=
hidden_states
.
device
...
...
tests/models/test_modeling_common.py
View file @
5df02fc1
...
@@ -1528,14 +1528,16 @@ class ModelTesterMixin:
...
@@ -1528,14 +1528,16 @@ class ModelTesterMixin:
test_fn
(
torch
.
float8_e5m2
,
torch
.
float32
)
test_fn
(
torch
.
float8_e5m2
,
torch
.
float32
)
test_fn
(
torch
.
float8_e4m3fn
,
torch
.
bfloat16
)
test_fn
(
torch
.
float8_e4m3fn
,
torch
.
bfloat16
)
@
torch
.
no_grad
()
def
test_layerwise_casting_inference
(
self
):
def
test_layerwise_casting_inference
(
self
):
from
diffusers.hooks.layerwise_casting
import
DEFAULT_SKIP_MODULES_PATTERN
,
SUPPORTED_PYTORCH_LAYERS
from
diffusers.hooks.layerwise_casting
import
DEFAULT_SKIP_MODULES_PATTERN
,
SUPPORTED_PYTORCH_LAYERS
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
config
).
eval
()
model
=
self
.
model_class
(
**
config
)
model
=
model
.
to
(
torch_device
)
model
.
eval
()
base_slice
=
model
(
**
inputs_dict
)[
0
].
flatten
().
detach
().
cpu
().
numpy
()
model
.
to
(
torch_device
)
base_slice
=
model
(
**
inputs_dict
)[
0
].
detach
().
flatten
().
cpu
().
numpy
()
def
check_linear_dtype
(
module
,
storage_dtype
,
compute_dtype
):
def
check_linear_dtype
(
module
,
storage_dtype
,
compute_dtype
):
patterns_to_check
=
DEFAULT_SKIP_MODULES_PATTERN
patterns_to_check
=
DEFAULT_SKIP_MODULES_PATTERN
...
@@ -1573,6 +1575,7 @@ class ModelTesterMixin:
...
@@ -1573,6 +1575,7 @@ class ModelTesterMixin:
test_layerwise_casting
(
torch
.
float8_e4m3fn
,
torch
.
bfloat16
)
test_layerwise_casting
(
torch
.
float8_e4m3fn
,
torch
.
bfloat16
)
@
require_torch_accelerator
@
require_torch_accelerator
@
torch
.
no_grad
()
def
test_layerwise_casting_memory
(
self
):
def
test_layerwise_casting_memory
(
self
):
MB_TOLERANCE
=
0.2
MB_TOLERANCE
=
0.2
LEAST_COMPUTE_CAPABILITY
=
8.0
LEAST_COMPUTE_CAPABILITY
=
8.0
...
@@ -1706,10 +1709,6 @@ class ModelTesterMixin:
...
@@ -1706,10 +1709,6 @@ class ModelTesterMixin:
if
not
self
.
model_class
.
_supports_group_offloading
:
if
not
self
.
model_class
.
_supports_group_offloading
:
pytest
.
skip
(
"Model does not support group offloading."
)
pytest
.
skip
(
"Model does not support group offloading."
)
torch
.
manual_seed
(
0
)
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
=
self
.
model_class
(
**
init_dict
)
...
@@ -1725,7 +1724,7 @@ class ModelTesterMixin:
...
@@ -1725,7 +1724,7 @@ class ModelTesterMixin:
**
additional_kwargs
,
**
additional_kwargs
,
)
)
has_safetensors
=
glob
.
glob
(
f
"
{
tmpdir
}
/*.safetensors"
)
has_safetensors
=
glob
.
glob
(
f
"
{
tmpdir
}
/*.safetensors"
)
assert
has_safetensors
,
"No safetensors found in the directory."
self
.
assert
True
(
len
(
has_safetensors
)
>
0
,
"No safetensors found in the
offload
directory."
)
_
=
model
(
**
inputs_dict
)[
0
]
_
=
model
(
**
inputs_dict
)[
0
]
def
test_auto_model
(
self
,
expected_max_diff
=
5e-5
):
def
test_auto_model
(
self
,
expected_max_diff
=
5e-5
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment