Commit cf4664e8 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix tests

parent 7222a8ea
......@@ -301,7 +301,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if torch.is_floating_point(timesteps):
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
......
......@@ -379,7 +379,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if torch.is_floating_point(timesteps):
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
......
......@@ -117,8 +117,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
`jnp.ndarray`: scaled input sample
"""
(step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
sigma = scheduler_state.sigmas[step_index]
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
sigma = state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
......
......@@ -15,7 +15,6 @@
import gc
import tempfile
import time
import unittest
import numpy as np
......@@ -694,24 +693,6 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
assert test_callback_fn.has_been_called
assert number_of_steps == 20
def test_stable_diffusion_low_cpu_mem_usage(self):
pipeline_id = "stabilityai/stable-diffusion-2-base"
start_time = time.time()
pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16
)
pipeline_low_cpu_mem_usage.to(torch_device)
low_cpu_mem_usage_time = time.time() - start_time
start_time = time.time()
_ = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
)
normal_load_time = time.time() - start_time
assert 2 * low_cpu_mem_usage_time < normal_load_time
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment