"examples/vscode:/vscode.git/clone" did not exist on "e3103e171f08361dc7c781685bd8127c8f672249"
Commit 2a23ba0b authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix unet ops not entirely on GPU.

parent 844dbf97
......@@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment