Unverified Commit c69d3b65 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT`] Fix PEFT batch size > 1 (#338)

parent 29ee66d9
...@@ -89,10 +89,10 @@ class WQLinearMMFunction(Function): ...@@ -89,10 +89,10 @@ class WQLinearMMFunction(Function):
) )
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
# 2D matrix multiplication, unsqueeze to 3D # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
grad_input = grad_output.squeeze(0).mm( # to propagate gradient across all batch sizes.
weights.transpose(0, 1) batch_size = grad_output.shape[0]
).unsqueeze(0) grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1))
return grad_input, None, None, None, None, None, None, None return grad_input, None, None, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment