Commit d9a5cb29 authored by Tri Dao's avatar Tri Dao
Browse files

Fix dv = torch::empty_like(k) for mha_bwd_varlen as well

parent a190df01
...@@ -1069,7 +1069,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -1069,7 +1069,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size); CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
} else { } else {
dv = torch::empty_like(k); dv = torch::empty_like(v);
} }
at::Tensor dout_padded; at::Tensor dout_padded;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment