Commit 1281f933 authored by comfyanonymous's avatar comfyanonymous
Browse files

Small optimization.

parent f2e844e0
...@@ -243,9 +243,9 @@ class TimestepEmbedder(nn.Module): ...@@ -243,9 +243,9 @@ class TimestepEmbedder(nn.Module):
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(
-math.log(max_period) -math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half / half
).to(device=t.device) )
args = t[:, None].float() * freqs[None] args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment