Unverified Commit 0feb21a1 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

[Tests] Fix mps+generator fast tests (#1230)

* [Tests] Fix mps+generator fast tests

* mps for Euler

* retry

* warmup issue again?

* fix reproducible initial noise

* Revert "fix reproducible initial noise"

This reverts commit f300d05cb9782ed320064a0c58577a32d4139e6d.

* fix reproducible initial noise

* fix device
parent 187de443
...@@ -136,7 +136,7 @@ jobs: ...@@ -136,7 +136,7 @@ jobs:
- name: Run fast PyTorch tests on M1 (MPS) - name: Run fast PyTorch tests on M1 (MPS)
shell: arch -arch arm64 bash {0} shell: arch -arch arm64 bash {0}
run: | run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/ ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
- name: Failure short reports - name: Failure short reports
if: ${{ failure() }} if: ${{ failure() }}
......
...@@ -78,7 +78,7 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -78,7 +78,7 @@ class DDIMPipeline(DiffusionPipeline):
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
message = ( message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline " f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be set to `None`. " f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `generator=torch.Generator(device="{self.device}")` instead.' f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
) )
deprecate( deprecate(
...@@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline):
generator = None generator = None
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), if self.device.type == "mps":
generator=generator, # randn does not work reproducibly on mps
device=self.device, image = torch.randn(image_shape, generator=generator)
) image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
......
...@@ -83,7 +83,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -83,7 +83,7 @@ class DDPMPipeline(DiffusionPipeline):
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
message = ( message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline " f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be set to `None`. " f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `torch.Generator(device="{self.device}")` instead.' f'Please use `torch.Generator(device="{self.device}")` instead.'
) )
deprecate( deprecate(
...@@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline):
generator = None generator = None
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), if self.device.type == "mps":
generator=generator, # randn does not work reproducibly on mps
device=self.device, image = torch.randn(image_shape, generator=generator)
) image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
......
...@@ -81,10 +81,14 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -81,10 +81,14 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
if torch_device == "mps": if torch_device == "mps":
_ = ddpm(num_inference_steps=1) _ = ddpm(num_inference_steps=1)
generator = torch.Generator(device=torch_device).manual_seed(0) if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.Generator(device=torch_device).manual_seed(0) generator = generator.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0] image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
......
...@@ -1281,7 +1281,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1281,7 +1281,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
scheduler.set_timesteps(self.num_inference_steps) scheduler.set_timesteps(self.num_inference_steps)
generator = torch.Generator(torch_device).manual_seed(0) if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma
...@@ -1308,7 +1312,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1308,7 +1312,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
scheduler.set_timesteps(self.num_inference_steps, device=torch_device) scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
generator = torch.Generator(torch_device).manual_seed(0) if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma
...@@ -1364,7 +1372,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1364,7 +1372,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler.set_timesteps(self.num_inference_steps) scheduler.set_timesteps(self.num_inference_steps)
generator = torch.Generator(device=torch_device).manual_seed(0) if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma
...@@ -1381,7 +1393,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1381,7 +1393,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
if str(torch_device).startswith("cpu"): if torch_device in ["cpu", "mps"]:
assert abs(result_sum.item() - 152.3192) < 1e-2 assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3 assert abs(result_mean.item() - 0.1983) < 1e-3
else: else:
...@@ -1396,7 +1408,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -1396,7 +1408,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler.set_timesteps(self.num_inference_steps, device=torch_device) scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
generator = torch.Generator(device=torch_device).manual_seed(0) if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment