Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
be54a95b
Unverified
Commit
be54a95b
authored
Mar 15, 2025
by
Dimitri Barbot
Committed by
GitHub
Mar 15, 2025
Browse files
Fix deterministic issue when getting pipeline dtype and device (#10696)
Co-authored-by:
Dhruv Nair
<
dhruv.nair@gmail.com
>
parent
6b9a3334
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
4 deletions
+107
-4
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+5
-3
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+102
-1
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
be54a95b
...
@@ -1610,7 +1610,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -1610,7 +1610,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
expected_modules
.
add
(
name
)
expected_modules
.
add
(
name
)
optional_parameters
.
remove
(
name
)
optional_parameters
.
remove
(
name
)
return
expected_modules
,
optional_parameters
return
sorted
(
expected_modules
)
,
sorted
(
optional_parameters
)
@
classmethod
@
classmethod
def
_get_signature_types
(
cls
):
def
_get_signature_types
(
cls
):
...
@@ -1652,10 +1652,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -1652,10 +1652,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
k
:
getattr
(
self
,
k
)
for
k
in
self
.
config
.
keys
()
if
not
k
.
startswith
(
"_"
)
and
k
not
in
optional_parameters
k
:
getattr
(
self
,
k
)
for
k
in
self
.
config
.
keys
()
if
not
k
.
startswith
(
"_"
)
and
k
not
in
optional_parameters
}
}
if
set
(
components
.
keys
())
!=
expected_modules
:
actual
=
sorted
(
set
(
components
.
keys
()))
expected
=
sorted
(
expected_modules
)
if
actual
!=
expected
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
}
has been incorrectly initialized or
{
self
.
__class__
}
is incorrectly implemented. Expected"
f
"
{
self
}
has been incorrectly initialized or
{
self
.
__class__
}
is incorrectly implemented. Expected"
f
"
{
expected
_modules
}
to be defined, but
{
components
.
keys
()
}
are defined."
f
"
{
expected
}
to be defined, but
{
actual
}
are defined."
)
)
return
components
return
components
...
...
tests/pipelines/test_pipeline_utils.py
View file @
be54a95b
...
@@ -19,7 +19,7 @@ from diffusers import (
...
@@ -19,7 +19,7 @@ from diffusers import (
UNet2DConditionModel
,
UNet2DConditionModel
,
)
)
from
diffusers.pipelines.pipeline_loading_utils
import
is_safetensors_compatible
,
variant_compatible_siblings
from
diffusers.pipelines.pipeline_loading_utils
import
is_safetensors_compatible
,
variant_compatible_siblings
from
diffusers.utils.testing_utils
import
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
,
torch_device
class
IsSafetensorsCompatibleTests
(
unittest
.
TestCase
):
class
IsSafetensorsCompatibleTests
(
unittest
.
TestCase
):
...
@@ -826,3 +826,104 @@ class ProgressBarTests(unittest.TestCase):
...
@@ -826,3 +826,104 @@ class ProgressBarTests(unittest.TestCase):
with
io
.
StringIO
()
as
stderr
,
contextlib
.
redirect_stderr
(
stderr
):
with
io
.
StringIO
()
as
stderr
,
contextlib
.
redirect_stderr
(
stderr
):
_
=
pipe
(
**
inputs
)
_
=
pipe
(
**
inputs
)
self
.
assertTrue
(
stderr
.
getvalue
()
==
""
,
"Progress bar should be disabled"
)
self
.
assertTrue
(
stderr
.
getvalue
()
==
""
,
"Progress bar should be disabled"
)
@require_torch_gpu
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
    """Regression tests for deterministic `pipe.device` / `pipe.dtype` accessors.

    Added by commit "Fix deterministic issue when getting pipeline dtype and
    device" (#10696): when a pipeline's submodules live on different devices or
    use different dtypes, the reported pipeline device/dtype should be stable
    across runs rather than depend on (non-deterministic) module iteration
    order.
    """

    # With submodules spread across cpu / cuda / cuda:0, the pipeline is
    # expected to report "cuda:0" (see test_deterministic_device below).
    expected_pipe_device = torch.device("cuda:0")
    # With submodules in float16 / float32 / float64, the pipeline is expected
    # to report float64 — presumably the widest dtype present; confirm against
    # DiffusionPipeline.dtype in pipeline_utils.py.
    expected_pipe_dtype = torch.float64

    def get_dummy_components_image_generation(self):
        """Build a minimal set of StableDiffusionPipeline components.

        Returns a dict of tiny randomly-initialized modules (unet, scheduler,
        vae, text_encoder, tokenizer) plus None placeholders for the optional
        components, suitable for `StableDiffusionPipeline(**components)`.
        """
        cross_attention_dim = 8

        # Reseed before each randomly-initialized module so components are
        # reproducible independently of construction order.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=1,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=cross_attention_dim,
            norm_num_groups=2,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=cross_attention_dim,
            intermediate_size=16,
            layer_norm_eps=1e-05,
            num_attention_heads=2,
            num_hidden_layers=2,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        # NOTE(review): downloads a tiny tokenizer from the HF Hub — requires
        # network access when not cached.
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "image_encoder": None,
        }
        return components

    def test_deterministic_device(self):
        """Pipeline device must be deterministic when modules span devices."""
        components = self.get_dummy_components_image_generation()

        pipe = StableDiffusionPipeline(**components)
        pipe.to(device=torch_device, dtype=torch.float32)

        # Deliberately scatter submodules across device spellings: "cpu",
        # bare "cuda", and explicit "cuda:0".
        pipe.unet.to(device="cpu")
        pipe.vae.to(device="cuda")
        pipe.text_encoder.to(device="cuda:0")

        pipe_device = pipe.device

        self.assertEqual(
            self.expected_pipe_device,
            pipe_device,
            f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
        )

    def test_deterministic_dtype(self):
        """Pipeline dtype must be deterministic when modules use mixed dtypes."""
        components = self.get_dummy_components_image_generation()

        pipe = StableDiffusionPipeline(**components)
        pipe.to(device=torch_device, dtype=torch.float32)

        # Deliberately give submodules three different dtypes; the accessor
        # must resolve them to a single stable answer (float64 here).
        pipe.unet.to(dtype=torch.float16)
        pipe.vae.to(dtype=torch.float32)
        pipe.text_encoder.to(dtype=torch.float64)

        pipe_dtype = pipe.dtype

        self.assertEqual(
            self.expected_pipe_dtype,
            pipe_dtype,
            f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment