renzhc / diffusers_dcu · Commits

Commit cf4664e8

fix tests

Authored Dec 02, 2022 by Patrick von Platen
Parent: 7222a8ea
Showing 4 changed files with 4 additions and 23 deletions (+4 −23)
src/diffusers/models/unet_2d_condition.py  (+1 −1)
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py  (+1 −1)
src/diffusers/schedulers/scheduling_lms_discrete_flax.py  (+2 −2)
tests/pipelines/stable_diffusion_2/test_stable_diffusion.py  (+0 −19)
src/diffusers/models/unet_2d_condition.py

@@ -301,7 +301,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
-            if torch.is_floating_point(timesteps):
+            if isinstance(timestep, float):
                 dtype = torch.float32 if is_mps else torch.float64
             else:
                 dtype = torch.int32 if is_mps else torch.int64
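The one-line change above is the substance of the test fix: torch.is_floating_point only accepts a Tensor, so it raises a TypeError when timestep reaches this branch as a plain Python float, whereas isinstance(timestep, float) handles Python numbers. A minimal standalone sketch of the failure mode (not from the commit; the toy value is illustrative):

    import torch

    timestep = 0.5  # hypothetical: a caller passed a plain Python float

    # Old check: torch.is_floating_point expects a Tensor and rejects a float.
    try:
        torch.is_floating_point(timestep)
    except TypeError as exc:
        print(f"old check fails: {exc}")

    # New check: isinstance works on plain Python numbers.
    is_mps = False  # assumption: not running on an Apple MPS device
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps else torch.float64  # MPS lacks float64
    else:
        dtype = torch.int32 if is_mps else torch.int64

    timesteps = torch.tensor([timestep], dtype=dtype)
    print(timesteps.dtype)  # torch.float64 on non-MPS devices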
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -379,7 +379,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
-            if torch.is_floating_point(timesteps):
+            if isinstance(timestep, float):
                 dtype = torch.float32 if is_mps else torch.float64
             else:
                 dtype = torch.int32 if is_mps else torch.int64
src/diffusers/schedulers/scheduling_lms_discrete_flax.py

@@ -117,8 +117,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         Returns:
             `jnp.ndarray`: scaled input sample
         """
-        (step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
-        sigma = scheduler_state.sigmas[step_index]
+        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+        sigma = state.sigmas[step_index]

         sample = sample / ((sigma**2 + 1) ** 0.5)
         return sample
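Here the rename fixes a stale reference: the method's parameter is named state, so scheduler_state was evidently undefined inside scale_model_input. A standalone sketch of the lookup pattern the fixed lines use (the toy timesteps/sigmas values are assumptions, not from the file); passing a static size=1 to jnp.where gives the index array a fixed shape, which keeps the lookup usable under jax.jit:

    import jax.numpy as jnp

    # Toy stand-ins for the LMSDiscreteSchedulerState fields (assumptions).
    timesteps = jnp.array([999, 749, 499, 249])
    sigmas = jnp.array([14.6, 7.1, 2.9, 0.6])

    timestep = 499
    # size=1 fixes the output shape, so this stays jit-compatible.
    (step_index,) = jnp.where(timesteps == timestep, size=1)
    sigma = sigmas[step_index]

    sample = jnp.ones((1, 4, 8, 8))
    scaled = sample / ((sigma**2 + 1) ** 0.5)  # same scaling as scale_model_input
    print(step_index, sigma)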
tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

@@ -15,7 +15,6 @@
 import gc
 import tempfile
-import time
 import unittest

 import numpy as np

@@ -694,24 +693,6 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
         assert test_callback_fn.has_been_called
         assert number_of_steps == 20

-    def test_stable_diffusion_low_cpu_mem_usage(self):
-        pipeline_id = "stabilityai/stable-diffusion-2-base"
-
-        start_time = time.time()
-        pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16
-        )
-        pipeline_low_cpu_mem_usage.to(torch_device)
-        low_cpu_mem_usage_time = time.time() - start_time
-
-        start_time = time.time()
-        _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
-        )
-        normal_load_time = time.time() - start_time
-
-        assert 2 * low_cpu_mem_usage_time < normal_load_time
-
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
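The deleted test_stable_diffusion_low_cpu_mem_usage asserted that default loading is at least twice as fast as loading with low_cpu_mem_usage=False, a wall-clock comparison that is presumably too fragile on shared CI hardware; removing it also makes the time import unused, hence the first hunk. For reference, a hedged sketch of the flag itself (model id taken from the deleted test; not part of the commit):

    import torch
    from diffusers import StableDiffusionPipeline

    # low_cpu_mem_usage=True (the default when accelerate is installed) loads
    # checkpoint weights directly into the model instead of materializing
    # randomly initialized parameters first, lowering peak CPU RAM and load time.
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-base",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    pipe.to("cuda")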