"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "561ab54de3d3aaa9007e76aeb3b15e8be3ed353f"
Unverified Commit dcf836cf authored by hlky's avatar hlky Committed by GitHub
Browse files

Use float32 on mps or npu in transformer_hidream_image's rope (#11316)

parent 1cb73cb1
...@@ -95,7 +95,12 @@ class HiDreamImagePatchEmbed(nn.Module): ...@@ -95,7 +95,12 @@ class HiDreamImagePatchEmbed(nn.Module):
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even." assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim is_mps = pos.device.type == "mps"
is_npu = pos.device.type == "npu"
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale) omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape batch_size, seq_length = pos.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment