Unverified commit d886f881, authored by Woosuk Kwon and committed by GitHub

Remove redundant code in `varlen_fwd` (#28)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 5259c586
```diff
@@ -659,22 +659,23 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      p_dropout, /*num_splits*/ 0, dprops, opts);
     }
+    // NOTE(woosuk): Commented out because they are not used in inference.
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
+    // int64_t counter_offset = params.b * params.h * 32;
+    // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // // Forward kernel will populate memory with the seed and offset.
+    // params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    // if (p_dropout > 0.0) {
+    //     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+    //         gen_, at::cuda::detail::getDefaultCUDAGenerator());
+    //     // See Note [Acquire lock when using random generators]
+    //     std::lock_guard<std::mutex> lock(gen->mutex_);
+    //     params.philox_args = gen->philox_cuda_state(counter_offset);
+    // }
     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
@@ -697,12 +698,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
         int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
-        out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
-        q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        // NOTE(woosuk): The two lines are not necessary because out_padded and q_padded are not used.
+        // out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        // q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse};
 }
 std::vector<at::Tensor>
...
```
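For context on the second hunk: the reshape/transpose bookkeeping kept for `out` regroups the kernel output from (batch, seqlen_q, heads, head_size) into (batch, heads * seqlen_q, head_size). Repeating that transform on `out_padded` and `q_padded` is dead work once those tensors are no longer returned, which is why those lines are commented out. The sketch below mirrors the kept line in PyTorch with made-up shape values; it is an illustration, not part of the commit.

```python
import torch

# Assumed, illustrative shape values; the real ones come from the kernel parameters.
batch_size, max_seqlen_q, num_heads_k, head_size_og = 2, 4, 3, 8

# Stand-in for the kernel's flat output buffer.
out = torch.randn(batch_size * max_seqlen_q * num_heads_k, head_size_og)

size_before = (batch_size, max_seqlen_q, num_heads_k, head_size_og)
size_after = (batch_size, num_heads_k * max_seqlen_q, head_size_og)

# Same transform as the kept C++ line:
#   out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
out = out.reshape(size_before).transpose(1, 2).reshape(size_after)
assert out.shape == size_after  # (2, 12, 8)
```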
```diff
@@ -88,7 +88,7 @@ def _flash_attn_varlen_forward(
     out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.varlen_fwd(
+    out, softmax_lse = torch.ops.vllm_flash_attn_c.varlen_fwd(
         q,
         k,
         v,
@@ -112,7 +112,9 @@ def _flash_attn_varlen_forward(
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
+    # because we only use the forward pass in the vLLM.
+    return out, q, k, v, None, softmax_lse, None, None
 def _flash_attn_backward(
...
```
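On the Python side, the op now hands back only `out` and `softmax_lse`, and the wrapper pads its return value with `None` so call sites written against the original 8-tuple interface keep unpacking without changes. The sketch below is a self-contained illustration of that contract; `fake_varlen_forward` is a hypothetical stand-in for `_flash_attn_varlen_forward`, and only the tuple layout is taken from the diff.

```python
import torch

def fake_varlen_forward(q, k, v):
    # Hypothetical stand-in: the real wrapper calls the CUDA op, which now returns
    # just (out, softmax_lse), and re-packs them into the original 8-slot tuple.
    out = torch.zeros_like(q)                          # stand-in for the attention output
    softmax_lse = torch.zeros(q.shape[1], q.shape[0])  # stand-in; real shape comes from the kernel
    return out, q, k, v, None, softmax_lse, None, None

q = k = v = torch.randn(6, 4, 64)  # (total_tokens, num_heads, head_size)
out, q_, k_, v_, out_padded, softmax_lse, S_dmask, rng_state = fake_varlen_forward(q, k, v)

# Inference-only callers read just `out` (and optionally softmax_lse); the padded
# output, dropout mask, and RNG state are no longer produced, so their slots are None.
assert out_padded is None and S_dmask is None and rng_state is None
```

Keeping the tuple shape avoids touching every downstream call site, while the C++ side drops the work (RNG setup, extra reshapes, extra return values) that inference never uses.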