Unverified Commit d8bc1a4e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Torch.compile] Fixes torch compile graph break (#4315)

* fix torch compile

* Fix all

* make style
parent 80c10d82
......@@ -14,6 +14,7 @@
from typing import Optional
import torch.nn.functional as F
from torch import nn
......@@ -91,7 +92,9 @@ class LoRACompatibleConv(nn.Conv2d):
def forward(self, x):
if self.lora_layer is None:
return super().forward(x)
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return super().forward(x) + self.lora_layer(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment