Commit 11e3221f authored by comfyanonymous's avatar comfyanonymous
Browse files

fp8 weight support for Stable Cascade.

parent f8706546
...@@ -84,7 +84,7 @@ class GlobalResponseNorm(nn.Module): ...@@ -84,7 +84,7 @@ class GlobalResponseNorm(nn.Module):
def forward(self, x): def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
class ResBlock(nn.Module): class ResBlock(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment