Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
7d3409be
Unverified
Commit
7d3409be
authored
Nov 28, 2024
by
Woosuk Kwon
Committed by
GitHub
Nov 28, 2024
Browse files
Remove redundant code in `mha_fwd` (#29)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
d886f881
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
21 deletions
+25
-21
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api.cpp
+21
-19
vllm_flash_attn/flash_attn_interface.py
vllm_flash_attn/flash_attn_interface.py
+4
-2
No files found.
csrc/flash_attn/flash_api.cpp
View file @
7d3409be
...
...
@@ -406,22 +406,23 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
params
,
batch_size
,
num_heads
,
head_size
,
seqlen_k
,
seqlen_q
,
head_size_rounded
,
p_dropout
,
/*num_splits*/
0
,
dprops
,
opts
);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t
counter_offset
=
params
.
b
*
params
.
h
*
32
;
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
auto
rng_state
=
torch
::
empty
({
2
},
options
.
dtype
(
torch
::
kInt64
));
// Forward kernel will populate memory with the seed and offset.
params
.
rng_state
=
reinterpret_cast
<
uint64_t
*>
(
rng_state
.
data_ptr
());
// NOTE(woosuk): Commented out because they are not used in inference.
// // number of times random will be generated per thread, to offset philox counter in thc random
// // state
// // We use a custom RNG that increases the offset by batch_size * nheads * 32.
// int64_t counter_offset = params.b * params.h * 32;
// auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
// auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// // Forward kernel will populate memory with the seed and offset.
// params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
if
(
p_dropout
>
0.0
)
{
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
gen_
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
// See Note [Acquire lock when using random generators]
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
->
mutex_
);
params
.
philox_args
=
gen
->
philox_cuda_state
(
counter_offset
);
}
//
if (p_dropout > 0.0) {
//
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
//
gen_, at::cuda::detail::getDefaultCUDAGenerator());
//
// See Note [Acquire lock when using random generators]
//
std::lock_guard<std::mutex> lock(gen->mutex_);
//
params.philox_args = gen->philox_cuda_state(counter_offset);
//
}
set_params_alibi
(
params
,
alibi_slopes_
,
batch_size
,
num_heads
);
...
...
@@ -442,11 +443,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
if
(
seqlenq_ngroups_swapped
)
{
out
=
out
.
transpose
(
1
,
2
).
reshape
({
batch_size
,
1
,
num_heads_k
*
seqlen_q
,
head_size_og
});
out_padded
=
out_padded
.
transpose
(
1
,
2
).
reshape
({
batch_size
,
1
,
num_heads_k
*
seqlen_q
,
head_size_og
});
q_padded
=
q_padded
.
transpose
(
1
,
2
).
reshape
({
batch_size
,
1
,
num_heads_k
*
seqlen_q
,
head_size_og
});
// NOTE(woosuk): The two lines are not needed because out_padded and q_padded are not used.
// out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
// q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse
=
softmax_lse
.
reshape
({
batch_size
,
num_heads_k
*
seqlen_q
,
1
});
}
return
{
out
,
q_padded
,
k_padded
,
v_padded
,
out_padded
,
softmax_lse
,
p
,
rng_stat
e
};
return
{
out
,
softmax_ls
e
};
}
std
::
vector
<
at
::
Tensor
>
...
...
@@ -698,7 +700,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int64_t
size_before
[]
=
{
batch_size
,
max_seqlen_q
,
num_heads_k
,
head_size_og
};
int64_t
size_after
[]
=
{
batch_size
,
num_heads_k
*
max_seqlen_q
,
head_size_og
};
out
=
out
.
reshape
(
size_before
).
transpose
(
1
,
2
).
reshape
(
size_after
);
// NOTE(woosuk): The two lines are not ne
cessary
because out_padded and q_padded are not used.
// NOTE(woosuk): The two lines are not ne
eded
because out_padded and q_padded are not used.
// out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
// q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
softmax_lse
=
softmax_lse
.
reshape
({
num_heads
*
max_seqlen_q
,
batch_size
});
...
...
vllm_flash_attn/flash_attn_interface.py
View file @
7d3409be
...
...
@@ -50,7 +50,7 @@ def _flash_attn_forward(
q
,
k
,
v
,
dropout_p
,
softmax_scale
,
causal
,
window_size
,
softcap
,
alibi_slopes
,
return_softmax
,
*
,
out
=
None
):
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
out
,
q
,
k
,
v
,
out_padded
,
softmax_lse
,
S_dmask
,
rng_stat
e
=
torch
.
ops
.
vllm_flash_attn_c
.
fwd
(
out
,
softmax_ls
e
=
torch
.
ops
.
vllm_flash_attn_c
.
fwd
(
q
,
k
,
v
,
...
...
@@ -65,7 +65,9 @@ def _flash_attn_forward(
return_softmax
,
None
,
)
return
out
,
q
,
k
,
v
,
out_padded
,
softmax_lse
,
S_dmask
,
rng_state
# NOTE(woosuk): out_padded, S_dmask, and rng_state are None
# because we only use the forward pass in the vLLM.
return
out
,
q
,
k
,
v
,
out
,
softmax_lse
,
None
,
None
def
_flash_attn_varlen_forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment