Commit fcef47f0 authored by comfyanonymous

Fix bug: the extra conditioning embeddings in SDXLRefiner and SDXL were appended in the wrong order. Embed height before width, crop_h before crop_w, and target_height before target_width.

parent 2d880fec
@@ -156,10 +156,10 @@ class SDXLRefiner(BaseModel):
         print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([aesthetic_score])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
@@ -180,11 +180,11 @@ class SDXL(BaseModel):
         print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
-        out.append(self.embedder(torch.Tensor([target_width])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([target_height])))
+        out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
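For context, both encode_adm methods build the SDXL ADM conditioning the same way: each scalar is run through a sinusoidal embedder, the six embeddings (five for the refiner) are flattened, and the pooled CLIP vector is prepended. A minimal standalone sketch of the corrected ordering, using a hypothetical timestep_embedding helper in place of the model's self.embedder and assuming the usual 256-dim embedding:

import math
import torch

def timestep_embedding(x, dim=256):
    # Hypothetical stand-in for self.embedder: a standard sinusoidal
    # (timestep-style) embedding of each conditioning scalar.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

def sdxl_adm(clip_pooled, width, height, crop_w, crop_h, target_width, target_height):
    # Corrected order from this commit: height before width, crop_h before
    # crop_w, target_height before target_width.
    out = []
    for value in (height, width, crop_h, crop_w, target_height, target_width):
        out.append(timestep_embedding(torch.Tensor([value])))
    flat = torch.flatten(torch.cat(out))[None, ]
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

# Pooled CLIP for SDXL is (1, 1280); six 256-dim embeddings bring the ADM
# vector to 1280 + 6 * 256 = 2816 values.
adm = sdxl_adm(torch.zeros(1, 1280), 1024, 1024, 0, 0, 1024, 1024)
print(adm.shape)  # torch.Size([1, 2816])

Since the per-scalar embeddings are simply concatenated, swapping any pair silently feeds the model conditioning values in positions it was not trained for, which is the bug fixed above.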