Unverified Commit 7d3409be authored by Woosuk Kwon, committed by GitHub

Remove redundant code in `mha_fwd` (#29)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent d886f881
@@ -406,22 +406,23 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
         head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
+    // NOTE(woosuk): Commented out because they are not used in inference.
+    // // number of times random will be generated per thread, to offset philox counter in thc random
+    // // state
+    // // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    // int64_t counter_offset = params.b * params.h * 32;
+    // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // // Forward kernel will populate memory with the seed and offset.
+    // params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    // if (p_dropout > 0.0) {
+    //     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+    //         gen_, at::cuda::detail::getDefaultCUDAGenerator());
+    //     // See Note [Acquire lock when using random generators]
+    //     std::lock_guard<std::mutex> lock(gen->mutex_);
+    //     params.philox_args = gen->philox_cuda_state(counter_offset);
+    // }
     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
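
The block commented out above only matters when `p_dropout > 0.0`, which never happens on vLLM's inference path, so the philox counter offset and `rng_state` buffer are never consumed. Below is a minimal, hedged sketch (not part of the commit) of an inference-style call through the Python wrapper; it assumes the `vllm_flash_attn` package with its compiled `vllm_flash_attn_c` extension is installed, a CUDA device is available, and that `flash_attn_func` is the exported wrapper.

```python
# Hedged sketch: an inference-style call with dropout disabled, the only
# configuration vLLM uses. With dropout_p=0.0 the philox RNG setup removed
# above is never consumed.
import torch
from vllm_flash_attn import flash_attn_func  # assumed exported wrapper

# (batch, seqlen, num_heads, head_size) half-precision tensors on the GPU.
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([1, 128, 8, 64])
```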
@@ -442,11 +443,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     if (seqlenq_ngroups_swapped) {
         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        // NOTE(woosuk): The two lines are not needed because out_padded and q_padded are not used.
+        // out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        // q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse};
 }

 std::vector<at::Tensor>
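
With the C++ return trimmed to `{out, softmax_lse}`, callers have to absorb the narrower interface; the Python hunks further down do this by padding the old 8-slot tuple with the unpadded output and `None`. A small sketch of that compatibility pattern follows (`expand_fwd_result` is a hypothetical name used only for illustration, not from the commit).

```python
# Sketch of the compatibility pattern used on the Python side: the two
# tensors returned by the C++ op are re-expanded into the legacy 8-tuple so
# existing unpacking code keeps working.
def expand_fwd_result(out, softmax_lse, q, k, v):
    # out_padded now aliases out; S_dmask and rng_state become None because
    # they only mattered for dropout and the backward pass.
    return out, q, k, v, out, softmax_lse, None, None
```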
@@ -698,7 +700,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
         int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
-        // NOTE(woosuk): The two lines are not necessary because out_padded and q_padded are not used.
+        // NOTE(woosuk): The two lines are not needed because out_padded and q_padded are not used.
         // out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         // q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
@@ -50,7 +50,7 @@ def _flash_attn_forward(
     q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.fwd(
+    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd(
         q,
         k,
         v,
@@ -65,7 +65,9 @@ def _flash_attn_forward(
         return_softmax,
         None,
     )
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
+    # because we only use the forward pass in the vLLM.
+    return out, q, k, v, out, softmax_lse, None, None

 def _flash_attn_varlen_forward(
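
For reference, a hedged usage sketch of `_flash_attn_forward` after this change, showing that an eight-value unpacking still works and which slots carry real data. The module path `vllm_flash_attn.flash_attn_interface`, a CUDA device, and the compiled extension are assumed.

```python
# Hedged usage sketch (not part of the commit).
import torch
from vllm_flash_attn.flash_attn_interface import _flash_attn_forward  # assumed module path

q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

out, q_, k_, v_, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
    q, k, v,
    dropout_p=0.0,                      # inference: no dropout
    softmax_scale=q.shape[-1] ** -0.5,
    causal=True,
    window_size=(-1, -1),               # no sliding window
    softcap=0.0,
    alibi_slopes=None,
    return_softmax=False,
)

assert out_padded is out                       # out_padded now aliases out
assert S_dmask is None and rng_state is None   # slots dropped by this commit
```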