Unverified Commit 3ceaa280 authored by Pedro Cuenca, committed by GitHub
Browse files

Do not use torch.long in mps (#1488)



* Do not use torch.long in mps

Addresses #1056.

* Use torch.int instead of float.

* Propagate changes.

* Do not silently change float -> int.

* Propagate changes.

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent a816a87a
...@@ -299,8 +299,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -299,8 +299,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
timesteps = timestep timesteps = timestep
if not torch.is_tensor(timesteps): if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) # This would be a good case for the `match` statement (Python 3.10+)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: is_mps = sample.device.type == "mps"
if torch.is_floating_point(timesteps):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
......
...@@ -377,8 +377,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -377,8 +377,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
timesteps = timestep timesteps = timestep
if not torch.is_tensor(timesteps): if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) # This would be a good case for the `match` statement (Python 3.10+)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: is_mps = sample.device.type == "mps"
if torch.is_floating_point(timesteps):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment