Unverified Commit 6ba2231d authored by Patrick von Platen, committed by GitHub

Reproducibility 3/3 (#1924)



* make tests deterministic

* run slow tests

* prepare for testing

* finish

* refactor

* add print statements

* finish more

* correct some test failures

* more fixes

* set up to correct tests

* more corrections

* up

* fix more

* more prints

* add

* up

* up

* up

* uP

* uP

* more fixes

* uP

* up

* up

* up

* up

* fix more

* up

* up

* clean tests

* up

* up

* up

* more fixes

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* make

* correct

* finish

* finish
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 008c22d3
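
The diffs below apply one pattern throughout the test suite: device-bound generators (torch.Generator(device=torch_device).manual_seed(0)) are replaced by the CPU default generator returned by torch.manual_seed(0), so a fixed seed produces the same latents, and therefore the same expected slices, on CPU, CUDA, and MPS. A minimal sketch of the idea, independent of any pipeline (shapes are illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.manual_seed(0) seeds the global RNG and returns the default (CPU)
# torch.Generator, so its result can be passed wherever a generator is expected.
generator = torch.manual_seed(0)

# Drawing the random latents on the CPU and moving them to the device afterwards
# keeps the draw identical regardless of which accelerator runs the model.
latents = torch.randn((1, 4, 64, 64), generator=generator).to(device)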
@@ -40,7 +40,7 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
        image_prompt = load_image(
            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
        )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image = pipe(
            image=image_prompt,
            generator=generator,
@@ -52,5 +52,6 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.1205, 0.1914, 0.2289, 0.0883, 0.1595, 0.1683, 0.0703, 0.1493, 0.1298])
+        expected_slice = np.array([0.0441, 0.0469, 0.0507, 0.0575, 0.0632, 0.0650, 0.0865, 0.0909, 0.0945])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -49,7 +49,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
        )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image = pipe.dual_guided(
            prompt="first prompt",
            image=prompt_image,
@@ -88,7 +88,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
        init_image = load_image(
            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
        )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image = pipe.dual_guided(
            prompt=prompt,
            image=init_image,
@@ -102,11 +102,12 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.0081, 0.0032, 0.0002, 0.0056, 0.0027, 0.0000, 0.0051, 0.0020, 0.0007])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+        expected_slice = np.array([0.1448, 0.1619, 0.1741, 0.1086, 0.1147, 0.1128, 0.1199, 0.1165, 0.1001])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1

        prompt = "A painting of a squirrel eating a burger "
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image = pipe.text_to_image(
            prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
        ).images
@@ -114,13 +115,15 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+        expected_slice = np.array([0.3367, 0.3169, 0.2656, 0.3870, 0.4790, 0.3796, 0.4009, 0.4878, 0.4778])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1

        image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.3403, 0.1809, 0.0938, 0.3855, 0.2393, 0.1243, 0.4028, 0.3110, 0.1799])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+        expected_slice = np.array([0.3076, 0.3123, 0.3284, 0.3782, 0.3770, 0.3894, 0.4297, 0.4331, 0.4456])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
@@ -48,7 +48,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
        pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger "
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image = pipe(
            prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy"
        ).images
@@ -72,7 +72,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
        pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger "
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image = pipe(
            prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
        ).images
@@ -80,5 +80,6 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222])
+        expected_slice = np.array([0.3493, 0.3757, 0.4093, 0.4495, 0.4233, 0.4102, 0.4507, 0.4756, 0.4787])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -212,6 +212,8 @@ class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

+        # requires GPU generator for gumbel softmax
+        # don't use GPU generator in tests though
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipeline(
            "teddy bear playing in the pool",
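
The VQ Diffusion test above deliberately keeps a device-bound generator: the scheduler adds gumbel noise to logits that live on the sampling device, and torch.rand can only consume a generator that lives on the same device as the requested tensor. A sketch of that constraint, with the noise formula along the lines of the scheduler's (shapes illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# The generator must match the device of the tensor being drawn; a CPU
# generator cannot produce a CUDA tensor directly.
generator = torch.Generator(device=device).manual_seed(0)
uniform = torch.rand((1, 4096), generator=generator, device=device)

# Gumbel noise derived from the uniform draw, as used for gumbel-softmax sampling.
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)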
@@ -86,19 +86,11 @@ class DownloadTests(unittest.TestCase):
        pipe = pipe.to(torch_device)
        pipe_2 = pipe_2.to(torch_device)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3
@@ -125,20 +117,12 @@ class DownloadTests(unittest.TestCase):
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe = pipe.to(torch_device)
-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        pipe_2 = pipe_2.to(torch_device)
-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3
@@ -149,11 +133,7 @@ class DownloadTests(unittest.TestCase):
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe = pipe.to(torch_device)
-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        with tempfile.TemporaryDirectory() as tmpdirname:
@@ -161,11 +141,7 @@ class DownloadTests(unittest.TestCase):

            pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
            pipe_2 = pipe_2.to(torch_device)
-            if torch_device == "mps":
-                # device type MPS is not supported for torch.Generator() api.
-                generator = torch.manual_seed(0)
-            else:
-                generator = torch.Generator(device=torch_device).manual_seed(0)
+            generator = torch.manual_seed(0)
            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
@@ -175,11 +151,8 @@ class DownloadTests(unittest.TestCase):
        prompt = "hello"

        pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        pipe = pipe.to(torch_device)
-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        with tempfile.TemporaryDirectory() as tmpdirname:
@@ -187,11 +160,7 @@ class DownloadTests(unittest.TestCase):

            pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
            pipe_2 = pipe_2.to(torch_device)
-            if torch_device == "mps":
-                # device type MPS is not supported for torch.Generator() api.
-                generator = torch.manual_seed(0)
-            else:
-                generator = torch.Generator(device=torch_device).manual_seed(0)
+            generator = torch.manual_seed(0)
            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
@@ -401,12 +370,7 @@ class PipelineFastTests(unittest.TestCase):
        scheduler = scheduler_fn()
        pipeline = pipeline_fn(unet, scheduler).to(torch_device)

-        # Device type MPS is not supported for torch.Generator() api.
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        out_image = pipeline(
            generator=generator,
            num_inference_steps=2,
@@ -442,12 +406,7 @@ class PipelineFastTests(unittest.TestCase):
        prompt = "A painting of a squirrel eating a burger"

-        # Device type MPS is not supported for torch.Generator() api.
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
        image_inpaint = inpaint(
            [prompt],
            generator=generator,
@@ -798,7 +757,7 @@ class PipelineSlowTests(unittest.TestCase):
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

-        generator = generator.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@@ -819,7 +778,7 @@ class PipelineSlowTests(unittest.TestCase):
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

-        generator = generator.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@@ -842,7 +801,7 @@ class PipelineSlowTests(unittest.TestCase):
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm_from_hub_custom_model(generator=generator, num_inference_steps=5, output_type="numpy").images

-        generator = generator.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
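
The three hunks above replace generator.manual_seed(0) with a freshly constructed generator. Generator.manual_seed reseeds the existing object in place and returns it, so both forms restart the stream from the same state; the rewrite just makes the reset explicit. A quick check of that equivalence:

import torch

g = torch.Generator().manual_seed(0)
a = torch.randn(3, generator=g)

# Reseeding in place restarts the stream...
g.manual_seed(0)
b = torch.randn(3, generator=g)

# ...exactly as a freshly constructed generator does.
c = torch.randn(3, generator=torch.Generator().manual_seed(0))
assert torch.equal(a, b) and torch.equal(b, c)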
@@ -855,18 +814,17 @@ class PipelineSlowTests(unittest.TestCase):
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
-        images = pipe(generator=generator, output_type="numpy").images
+        images = pipe(output_type="numpy").images
        assert images.shape == (1, 32, 32, 3)
        assert isinstance(images, np.ndarray)

-        images = pipe(generator=generator, output_type="pil", num_inference_steps=4).images
+        images = pipe(output_type="pil", num_inference_steps=4).images
        assert isinstance(images, list)
        assert len(images) == 1
        assert isinstance(images[0], PIL.Image.Image)

        # use PIL by default
-        images = pipe(generator=generator, num_inference_steps=4).images
+        images = pipe(num_inference_steps=4).images
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)
@@ -288,11 +288,11 @@ class SchedulerCommonTest(unittest.TestCase):
            # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
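
The common test only injects a generator when the scheduler's step signature accepts one, probing it with inspect.signature; torch.manual_seed(0) fits there because it returns the default torch.Generator. The pattern in isolation (the step function below is a stand-in, not the schedulers' API):

import inspect

import torch

def step(residual, timestep, sample, generator=None):
    # stand-in for scheduler.step: add seeded noise when a generator is given
    return sample + torch.randn(sample.shape, generator=generator)

kwargs = {}
# Only pass a generator if the callee actually accepts the keyword.
if "generator" in inspect.signature(step).parameters:
    kwargs["generator"] = torch.manual_seed(0)

out = step(torch.zeros(4), 0, torch.zeros(4), **kwargs)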
@@ -330,11 +330,11 @@ class SchedulerCommonTest(unittest.TestCase):
                kwargs["num_inference_steps"] = num_inference_steps

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -372,11 +372,11 @@ class SchedulerCommonTest(unittest.TestCase):
                kwargs["num_inference_steps"] = num_inference_steps

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -510,7 +510,7 @@ class SchedulerCommonTest(unittest.TestCase):
            # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
@@ -520,7 +520,7 @@ class SchedulerCommonTest(unittest.TestCase):
            # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-                kwargs["generator"] = torch.Generator().manual_seed(0)
+                kwargs["generator"] = torch.manual_seed(0)
            outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple, outputs_dict)
@@ -664,12 +664,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
        kwargs = {}
        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-            kwargs["generator"] = torch.Generator().manual_seed(0)
+            kwargs["generator"] = torch.manual_seed(0)
        output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample

        kwargs = {}
        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-            kwargs["generator"] = torch.Generator().manual_seed(0)
+            kwargs["generator"] = torch.manual_seed(0)
        output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample

        assert (output - output_eps).abs().sum() < 1e-5
@@ -1822,11 +1822,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler.set_timesteps(self.num_inference_steps)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1853,11 +1849,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler.set_timesteps(self.num_inference_steps)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1884,11 +1876,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1947,11 +1935,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler.set_timesteps(self.num_inference_steps)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1968,13 +1952,8 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-        if torch_device in ["cpu", "mps"]:
-            assert abs(result_sum.item() - 152.3192) < 1e-2
-            assert abs(result_mean.item() - 0.1983) < 1e-3
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 144.8084) < 1e-2
-            assert abs(result_mean.item() - 0.18855) < 1e-3
+        assert abs(result_sum.item() - 152.3192) < 1e-2
+        assert abs(result_mean.item() - 0.1983) < 1e-3

    def test_full_loop_with_v_prediction(self):
        scheduler_class = self.scheduler_classes[0]
@@ -1983,11 +1962,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler.set_timesteps(self.num_inference_steps)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -2004,13 +1979,8 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-        if torch_device in ["cpu", "mps"]:
-            assert abs(result_sum.item() - 108.4439) < 1e-2
-            assert abs(result_mean.item() - 0.1412) < 1e-3
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 102.5807) < 1e-2
-            assert abs(result_mean.item() - 0.1335) < 1e-3
+        assert abs(result_sum.item() - 108.4439) < 1e-2
+        assert abs(result_mean.item() - 0.1412) < 1e-3

    def test_full_loop_device(self):
        scheduler_class = self.scheduler_classes[0]
@@ -2018,12 +1988,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -2040,17 +2005,8 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-        if str(torch_device).startswith("cpu"):
-            # The following sum varies between 148 and 156 on mps. Why?
-            assert abs(result_sum.item() - 152.3192) < 1e-2
-            assert abs(result_mean.item() - 0.1983) < 1e-3
-        elif str(torch_device).startswith("mps"):
-            # Larger tolerance on mps
-            assert abs(result_mean.item() - 0.1983) < 1e-2
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 144.8084) < 1e-2
-            assert abs(result_mean.item() - 0.18855) < 1e-3
+        assert abs(result_sum.item() - 152.3192) < 1e-2
+        assert abs(result_mean.item() - 0.1983) < 1e-3


class IPNDMSchedulerTest(SchedulerCommonTest):
@@ -2745,7 +2701,7 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler.set_timesteps(self.num_inference_steps)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -2762,13 +2718,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-        if torch_device in ["cpu", "mps"]:
-            assert abs(result_sum.item() - 13849.3945) < 1e-2
-            assert abs(result_mean.item() - 18.0331) < 5e-3
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 13913.0449) < 1e-2
-            assert abs(result_mean.item() - 18.1159) < 5e-3
+        assert abs(result_sum.item() - 13849.3877) < 1e-2
+        assert abs(result_mean.item() - 18.0331) < 5e-3

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
@@ -2787,11 +2738,7 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        for i, t in enumerate(scheduler.timesteps):
            sample = scheduler.scale_model_input(sample, t)
@@ -2804,13 +2751,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-        if torch_device in ["cpu", "mps"]:
-            assert abs(result_sum.item() - 328.9970) < 1e-2
-            assert abs(result_mean.item() - 0.4284) < 1e-3
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 327.8027) < 1e-2
-            assert abs(result_mean.item() - 0.4268) < 1e-3
+        assert abs(result_sum.item() - 328.9970) < 1e-2
+        assert abs(result_mean.item() - 0.4284) < 1e-3

    def test_full_loop_device(self):
        if torch_device == "mps":
@@ -2820,12 +2762,7 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
@@ -2841,13 +2778,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-        if str(torch_device).startswith("cpu"):
-            assert abs(result_sum.item() - 13849.3945) < 1e-2
-            assert abs(result_mean.item() - 18.0331) < 5e-3
-        else:
-            # CUDA
-            assert abs(result_sum.item() - 13913.0332) < 1e-1
-            assert abs(result_mean.item() - 18.1159) < 1e-3
+        assert abs(result_sum.item() - 13849.3818) < 1e-1
+        assert abs(result_mean.item() - 18.0331) < 1e-3


# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from collections import defaultdict


def overwrite_file(file, class_name, test_name, correct_line, done_test):
    # Replace the next matching assertion line of `test_name` in `class_name`
    # with `correct_line`; `done_test` counts how many corrections have already
    # been applied to this file, so repeated (file, test) records advance to
    # the next matching occurrence.
    done_test[file] += 1
    with open(file, "r") as f:
        lines = f.readlines()

    class_regex = f"class {class_name}("
    test_regex = f"{4 * ' '}def {test_name}("
    line_begin_regex = f"{8 * ' '}{correct_line.split()[0]}"
    another_line_begin_regex = f"{16 * ' '}{correct_line.split()[0]}"

    in_class = False
    in_func = False
    in_line = False
    insert_line = False
    count = 0
    spaces = 0

    new_lines = []
    for line in lines:
        if line.startswith(class_regex):
            in_class = True
        elif in_class and line.startswith(test_regex):
            in_func = True
        elif in_class and in_func and (line.startswith(line_begin_regex) or line.startswith(another_line_begin_regex)):
            spaces = len(line.split(correct_line.split()[0])[0])
            count += 1

            if count == done_test[file]:
                in_line = True

        if in_class and in_func and in_line:
            # Drop the old (possibly multi-line) statement until its closing
            # parenthesis, then emit the corrected line in its place.
            if ")" not in line:
                continue
            else:
                insert_line = True

        if in_class and in_func and in_line and insert_line:
            new_lines.append(f"{spaces * ' '}{correct_line}")
            in_class = in_func = in_line = insert_line = False
        else:
            new_lines.append(line)

    with open(file, "w") as f:
        for line in new_lines:
            f.write(line)


def main(correct, fail=None):
    if fail is not None:
        with open(fail, "r") as f:
            test_failures = {l.strip() for l in f.readlines()}
    else:
        test_failures = None

    with open(correct, "r") as f:
        correct_lines = f.readlines()

    done_tests = defaultdict(int)
    for line in correct_lines:
        file, class_name, test_name, correct_line = line.split(";")
        # Apply a correction only if its test failed (or when no failure list was given).
        if test_failures is None or "::".join([file, class_name, test_name]) in test_failures:
            overwrite_file(file, class_name, test_name, correct_line, done_tests)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--correct_filename", help="filename of tests with expected result")
    parser.add_argument("--fail_filename", help="filename of test failures", type=str, default=None)
    args = parser.parse_args()

    main(args.correct_filename, args.fail_filename)
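
Per the argparse setup above, the script expects --correct_filename to point at a file of semicolon-separated records, file;class_name;test_name;corrected_line, and --fail_filename optionally at a list of pytest node ids (file::Class::test) that restricts which records are applied. A hypothetical invocation (the script path and both data file names are illustrative; the diff does not show the committed path):

# new_expected.txt — one record per corrected assertion, e.g.:
#   tests/test_scheduler.py;EulerDiscreteSchedulerTest;test_full_loop;assert abs(result_sum.item() - 152.3192) < 1e-2
# failures.txt — optional pytest node ids, e.g.:
#   tests/test_scheduler.py::EulerDiscreteSchedulerTest::test_full_loop
python overwrite_expected_slice.py --correct_filename new_expected.txt --fail_filename failures.txt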