Unverified Commit c69d3b65 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT`] Fix PEFT batch size > 1 (#338)

parent 29ee66d9
......@@ -89,10 +89,10 @@ class WQLinearMMFunction(Function):
)
if ctx.needs_input_grad[0]:
# 2D matrix multiplication, unsqueeze to 3D
grad_input = grad_output.squeeze(0).mm(
weights.transpose(0, 1)
).unsqueeze(0)
# 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
# to propagate gradient across all batch sizes.
batch_size = grad_output.shape[0]
grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1))
return grad_input, None, None, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment