"vscode:/vscode.git/clone" did not exist on "b9ae7e5da19573de3d9e727884f6609df04f935d"
Unverified Commit 75a636da authored by baymax591's avatar baymax591 Committed by GitHub
Browse files

bugfix for npu not support float64 (#10123)



* bugfix for npu not support float64

* is_mps is_npu

---------
Co-authored-by: default avatar白超 <baichao19@huawei.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 4842f5d8
......@@ -822,10 +822,11 @@ class PixArtSigmaPipeline(DiffusionPipeline):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment