Commit c24f8973 authored by comfyanonymous

Fix to get fp8 working on T5 base.

parent a5991a7a
@@ -236,4 +236,6 @@ class T5(torch.nn.Module):
     def forward(self, input_ids, *args, **kwargs):
         x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+        if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+            x = torch.nan_to_num(x)  # Fix for fp8 T5 base
         return self.encoder(x, *args, **kwargs)
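For context, a minimal sketch (not part of the commit) of what the added guard does: when self.dtype is not one of the standard float types, i.e. the weights are stored in an fp8 format, the embedding output can contain nan or inf values, and torch.nan_to_num replaces them with finite numbers before the encoder runs. The tensor values below are illustrative assumptions, not taken from the model.

    import torch

    # Illustrative embedding output that picked up non-finite values
    # (e.g. from fp8 casting/overflow).
    x = torch.tensor([0.5, float("nan"), float("inf"), -float("inf")])

    # By default torch.nan_to_num maps nan -> 0.0 and +/-inf to the dtype's
    # largest finite values, which is the sanitizing step the commit applies.
    print(torch.nan_to_num(x))
    # tensor([ 5.0000e-01,  0.0000e+00,  3.4028e+38, -3.4028e+38])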