Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
cf4664e8
You need to sign in or sign up before continuing.
Commit
cf4664e8
authored
Dec 02, 2022
by
Patrick von Platen
Browse files
fix tests
parent
7222a8ea
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
4 additions
and
23 deletions
+4
-23
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+1
-1
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+1
-1
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+2
-2
tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+0
-19
No files found.
src/diffusers/models/unet_2d_condition.py
View file @
cf4664e8
...
@@ -301,7 +301,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
...
@@ -301,7 +301,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
# This would be a good case for the `match` statement (Python 3.10+)
is_mps
=
sample
.
device
.
type
==
"mps"
is_mps
=
sample
.
device
.
type
==
"mps"
if
torch
.
is_floating_point
(
timestep
s
):
if
isinstance
(
timestep
,
float
):
dtype
=
torch
.
float32
if
is_mps
else
torch
.
float64
dtype
=
torch
.
float32
if
is_mps
else
torch
.
float64
else
:
else
:
dtype
=
torch
.
int32
if
is_mps
else
torch
.
int64
dtype
=
torch
.
int32
if
is_mps
else
torch
.
int64
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
cf4664e8
...
@@ -379,7 +379,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
...
@@ -379,7 +379,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
# This would be a good case for the `match` statement (Python 3.10+)
is_mps
=
sample
.
device
.
type
==
"mps"
is_mps
=
sample
.
device
.
type
==
"mps"
if
torch
.
is_floating_point
(
timestep
s
):
if
isinstance
(
timestep
,
float
):
dtype
=
torch
.
float32
if
is_mps
else
torch
.
float64
dtype
=
torch
.
float32
if
is_mps
else
torch
.
float64
else
:
else
:
dtype
=
torch
.
int32
if
is_mps
else
torch
.
int64
dtype
=
torch
.
int32
if
is_mps
else
torch
.
int64
...
...
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
View file @
cf4664e8
...
@@ -117,8 +117,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -117,8 +117,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
Returns:
`jnp.ndarray`: scaled input sample
`jnp.ndarray`: scaled input sample
"""
"""
(
step_index
,)
=
jnp
.
where
(
scheduler_
state
.
timesteps
==
timestep
,
size
=
1
)
(
step_index
,)
=
jnp
.
where
(
state
.
timesteps
==
timestep
,
size
=
1
)
sigma
=
scheduler_
state
.
sigmas
[
step_index
]
sigma
=
state
.
sigmas
[
step_index
]
sample
=
sample
/
((
sigma
**
2
+
1
)
**
0.5
)
sample
=
sample
/
((
sigma
**
2
+
1
)
**
0.5
)
return
sample
return
sample
...
...
tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
View file @
cf4664e8
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
import
gc
import
gc
import
tempfile
import
tempfile
import
time
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
...
@@ -694,24 +693,6 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
...
@@ -694,24 +693,6 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
assert
test_callback_fn
.
has_been_called
assert
test_callback_fn
.
has_been_called
assert
number_of_steps
==
20
assert
number_of_steps
==
20
def
test_stable_diffusion_low_cpu_mem_usage
(
self
):
pipeline_id
=
"stabilityai/stable-diffusion-2-base"
start_time
=
time
.
time
()
pipeline_low_cpu_mem_usage
=
StableDiffusionPipeline
.
from_pretrained
(
pipeline_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
)
pipeline_low_cpu_mem_usage
.
to
(
torch_device
)
low_cpu_mem_usage_time
=
time
.
time
()
-
start_time
start_time
=
time
.
time
()
_
=
StableDiffusionPipeline
.
from_pretrained
(
pipeline_id
,
revision
=
"fp16"
,
torch_dtype
=
torch
.
float16
,
use_auth_token
=
True
,
low_cpu_mem_usage
=
False
)
normal_load_time
=
time
.
time
()
-
start_time
assert
2
*
low_cpu_mem_usage_time
<
normal_load_time
def
test_stable_diffusion_pipeline_with_sequential_cpu_offloading
(
self
):
def
test_stable_diffusion_pipeline_with_sequential_cpu_offloading
(
self
):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_max_memory_allocated
()
torch
.
cuda
.
reset_max_memory_allocated
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment