Commit 46dc050c authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix potential tensors being on different devices issues.

parent 90aa5970
...@@ -32,7 +32,7 @@ class ClipTokenWeightEncoder: ...@@ -32,7 +32,7 @@ class ClipTokenWeightEncoder:
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
return z_empty, first_pooled return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
...@@ -139,7 +139,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): ...@@ -139,7 +139,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
pooled_output = outputs.pooler_output pooled_output = outputs.pooler_output
if self.text_projection is not None: if self.text_projection is not None:
pooled_output = pooled_output @ self.text_projection pooled_output = pooled_output.to(self.text_projection.device) @ self.text_projection
return z.float(), pooled_output.float() return z.float(), pooled_output.float()
def encode(self, tokens): def encode(self, tokens):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment