Commit c7821444 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix clip vision lowvram mode not working.

parent e478b179
...@@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module): ...@@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module):
def forward(self, pixel_values): def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module): class CLIPVision(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment