"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c307db4bea58b09183bdbca67460c00999f9a844"
Unverified Commit 75a636da authored by baymax591's avatar baymax591 Committed by GitHub
Browse files

bugfix for npu not support float64 (#10123)



* bugfix for npu not support float64

* is_mps is_npu

---------
Co-authored-by: default avatar白超 <baichao19@huawei.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 4842f5d8
...@@ -404,10 +404,11 @@ def my_forward( ...@@ -404,10 +404,11 @@ def my_forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -2806,10 +2806,11 @@ class MatryoshkaUNet2DConditionModel( ...@@ -2806,10 +2806,11 @@ class MatryoshkaUNet2DConditionModel(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -1031,10 +1031,11 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline): ...@@ -1031,10 +1031,11 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps" is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float): if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0: elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device) current_timestep = current_timestep[None].to(latent_model_input.device)
......
...@@ -258,10 +258,11 @@ class PromptDiffusionControlNetModel(ControlNetModel): ...@@ -258,10 +258,11 @@ class PromptDiffusionControlNetModel(ControlNetModel):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -740,10 +740,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -740,10 +740,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -671,10 +671,11 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -671,10 +671,11 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -681,10 +681,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -681,10 +681,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -1088,10 +1088,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -1088,10 +1088,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -915,10 +915,11 @@ class UNet2DConditionModel( ...@@ -915,10 +915,11 @@ class UNet2DConditionModel(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -624,10 +624,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -624,10 +624,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -575,10 +575,11 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -575,10 +575,11 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timesteps, float): if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -2114,10 +2114,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft ...@@ -2114,10 +2114,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -402,10 +402,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL ...@@ -402,10 +402,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -768,10 +768,11 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -768,10 +768,11 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -1163,10 +1163,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -1163,10 +1163,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps" is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float): if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
......
...@@ -187,10 +187,11 @@ class DiTPipeline(DiffusionPipeline): ...@@ -187,10 +187,11 @@ class DiTPipeline(DiffusionPipeline):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps" is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(timesteps, float): if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
elif len(timesteps.shape) == 0: elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(latent_model_input.device) timesteps = timesteps[None].to(latent_model_input.device)
......
...@@ -798,10 +798,11 @@ class LattePipeline(DiffusionPipeline): ...@@ -798,10 +798,11 @@ class LattePipeline(DiffusionPipeline):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps" is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float): if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0: elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device) current_timestep = current_timestep[None].to(latent_model_input.device)
......
...@@ -806,10 +806,11 @@ class LuminaText2ImgPipeline(DiffusionPipeline): ...@@ -806,10 +806,11 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps" is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float): if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor( current_timestep = torch.tensor(
[current_timestep], [current_timestep],
dtype=dtype, dtype=dtype,
......
...@@ -807,10 +807,11 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -807,10 +807,11 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps" is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float): if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0: elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device) current_timestep = current_timestep[None].to(latent_model_input.device)
......
...@@ -907,10 +907,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -907,10 +907,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+) # This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps" is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float): if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64 dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else: else:
dtype = torch.int32 if is_mps else torch.int64 dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0: elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device) current_timestep = current_timestep[None].to(latent_model_input.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment