Unverified Commit 2423cca3 authored by Brian Hirsh's avatar Brian Hirsh Committed by GitHub
Browse files

fix backward for when query and key have different contiguity (#818)

parent 46879364
......@@ -830,7 +830,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dv = torch::empty_like(k);
dv = torch::empty_like(v);
}
at::Tensor dout_padded;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment