Unverified Commit 7d3409be authored by Woosuk Kwon, committed by GitHub

Remove redundant code in `mha_fwd` (#29)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent d886f881
@@ -406,22 +406,23 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
                      head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    // NOTE(woosuk): Commented out because they are not used in inference.
+    // // number of times random will be generated per thread, to offset philox counter in thc random
+    // // state
+    // // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    // int64_t counter_offset = params.b * params.h * 32;
+    // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // // Forward kernel will populate memory with the seed and offset.
+    // params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
+    // if (p_dropout > 0.0) {
+    //     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+    //         gen_, at::cuda::detail::getDefaultCUDAGenerator());
+    //     // See Note [Acquire lock when using random generators]
+    //     std::lock_guard<std::mutex> lock(gen->mutex_);
+    //     params.philox_args = gen->philox_cuda_state(counter_offset);
+    // }
     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
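For context on why this block can go: the removed setup only matters when dropout is active. The philox counter would be advanced by batch_size * num_heads * 32 per forward pass so that later random draws use fresh offsets, but vLLM only runs this kernel for inference with dropout disabled, so the offset, the rng_state buffer, and params.philox_args are never consumed. A rough Python sketch of the arithmetic (the concrete sizes are illustrative, not taken from this commit):

# Mirrors the removed C++ line `int64_t counter_offset = params.b * params.h * 32;`.
batch_size, num_heads = 8, 32                   # illustrative values
counter_offset = batch_size * num_heads * 32    # = 8192 philox increments per forward
# With dropout_p == 0.0 (the only case vLLM exercises), nothing ever reads this
# offset, the rng_state buffer, or params.philox_args.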
@@ -442,11 +443,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     if (seqlenq_ngroups_swapped) {
         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        // NOTE(woosuk): The two lines are not needed because out_padded and q_padded are not used.
+        // out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        // q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse};
 }

 std::vector<at::Tensor>
@@ -698,7 +700,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
         int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
-        // NOTE(woosuk): The two lines are not necessary because out_padded and q_padded are not used.
+        // NOTE(woosuk): The two lines are not needed because out_padded and q_padded are not used.
         // out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         // q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
@@ -50,7 +50,7 @@ def _flash_attn_forward(
     q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.fwd(
+    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd(
         q,
         k,
         v,
@@ -65,7 +65,9 @@ def _flash_attn_forward(
         return_softmax,
         None,
     )
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
+    # because we only use the forward pass in the vLLM.
+    return out, q, k, v, out, softmax_lse, None, None


 def _flash_attn_varlen_forward(
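Taken together with the C++ change, the Python wrapper keeps its original 8-tuple return so existing call sites continue to unpack the same number of values; only out and softmax_lse carry real data now. A hedged usage sketch follows; the import path, tensor shapes, and flag values below are assumptions for illustration, not taken from this commit:

import torch
# Assumed module path; adjust to wherever _flash_attn_forward lives in your build.
from vllm_flash_attn.flash_attn_interface import _flash_attn_forward

# Illustrative shapes: batch_size x seqlen x num_heads x head_size, fp16 on GPU.
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

out, q_, k_, v_, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
    q, k, v,
    dropout_p=0.0,                      # inference only, so dropout is disabled
    softmax_scale=q.shape[-1] ** -0.5,
    causal=True,
    window_size=(-1, -1),               # no sliding window
    softcap=0.0,
    alibi_slopes=None,
    return_softmax=False,
)
# After this change, out_padded simply aliases out, and S_dmask / rng_state are None.
assert out_padded is out
assert S_dmask is None and rng_state is None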