Commit fb3b7282 authored by comfyanonymous

Fix issue where autocast fp32 CLIP gave different results from regular.

parent 7d401ed1
@@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if dtype is not None:
             self.transformer.to(dtype)
+            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
+            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
         self.max_length = max_length
         if freeze:
             self.freeze()
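The hunk above pins the token and position embedding tables to fp32 even when the rest of the transformer is cast to a lower dtype, so the embedding lookup produces the same values the plain fp32 model would see. A minimal sketch of the pattern, with an illustrative module (TinyTextModel and its sizes are not the actual ComfyUI code):

import torch
import torch.nn as nn

class TinyTextModel(nn.Module):
    def __init__(self, vocab=49408, dim=768, max_pos=77):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab, dim)
        self.position_embedding = nn.Embedding(max_pos, dim)
        self.final_layer_norm = nn.LayerNorm(dim)

model = TinyTextModel()
model.to(torch.float16)                  # cast the whole transformer down
model.token_embedding.to(torch.float32)  # then pin the embeddings back to fp32
model.position_embedding.to(torch.float32)

print(model.token_embedding.weight.dtype)     # torch.float32
print(model.position_embedding.weight.dtype)  # torch.float32
print(model.final_layer_norm.weight.dtype)    # torch.float16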
@@ -138,7 +141,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
-        if backup_embeds.weight.dtype != torch.float32:
+        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
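The second hunk follows from the first: once the embedding weights are always fp32, probing `backup_embeds.weight.dtype` would always report an fp32 model, so the autocast path would never be taken for fp16 models. The probe therefore moves to `final_layer_norm.weight`, which still carries the dtype the transformer was cast to. A minimal sketch of the selection logic (`pick_precision_scope` is a hypothetical helper, not part of the ComfyUI code):

import contextlib
import torch
import torch.nn as nn

def pick_precision_scope(probe_weight):
    # Probe a weight that still reflects the transformer's cast dtype.
    # The token embeddings can no longer serve as the probe: the first
    # hunk pins them to fp32, so they would always look like an fp32 model.
    if probe_weight.dtype != torch.float32:
        return torch.autocast  # lower-precision model: wrap forward in autocast
    # fp32 model: a no-op stand-in with the same (device, dtype) call shape
    return lambda a, b: contextlib.nullcontext(a)

ln_fp16 = nn.LayerNorm(8).to(torch.float16)
ln_fp32 = nn.LayerNorm(8)

assert pick_precision_scope(ln_fp16.weight) is torch.autocast
with pick_precision_scope(ln_fp32.weight)("cpu", torch.float32):
    pass  # no-op context: plain fp32 execution, no autocast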