Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
bde5aec8
Commit
bde5aec8
authored
Feb 13, 2024
by
skrider
Browse files
all working except rotary embedding
parent
fa13c6b0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
27 deletions
+23
-27
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+7
-6
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+8
-4
tests/test_flash_attn.py
tests/test_flash_attn.py
+8
-17
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
bde5aec8
...
@@ -597,15 +597,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -597,15 +597,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
tQgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tQgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_Q
.
partition_D
(
sQ
);
Tensor
tQsQ
=
gmem_thr_copy_Q
.
partition_D
(
sQ
);
Tensor
tKgK_
=
gmem_thr_copy_KV
.
partition_S
(
gK
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tKgK_
=
gmem_thr_copy_KV
.
partition_S
(
gK
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tKsK_
=
gmem_thr_copy_KV
.
partition_D
(
sK
);
Tensor
tKsK_
=
gmem_thr_copy_KV
.
partition_D
(
sK
);
Tensor
tVgV_
=
gmem_thr_copy_KV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVgV_
=
gmem_thr_copy_KV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV_
=
gmem_thr_copy_KV
.
partition_D
(
sV
);
Tensor
tVsV_
=
gmem_thr_copy_KV
.
partition_D
(
sV
);
Tensor
tKgK
=
make_tensor
(
tKgK_
.
data
(),
unsqueeze
<
2
>
(
layout
<
0
>
(
tKgK_
.
layout
()))
)
;
Tensor
tKgK
=
make_tensor
(
tKgK_
.
data
(),
reshape_thread_tile
(
tKgK_
.
layout
()));
Tensor
tKsK
=
make_tensor
(
tKsK_
.
data
(),
unsqueeze
<
2
>
(
layout
<
0
>
(
tKsK_
.
layout
()))
)
;
Tensor
tKsK
=
make_tensor
(
tKsK_
.
data
(),
reshape_thread_tile
(
tKsK_
.
layout
()));
Tensor
tVgV
=
make_tensor
(
tVgV_
.
data
(),
unsqueeze
<
2
>
(
layout
<
0
>
(
tVgV_
.
layout
()))
)
;
Tensor
tVgV
=
make_tensor
(
tVgV_
.
data
(),
reshape_thread_tile
(
tVgV_
.
layout
()));
Tensor
tVsV
=
make_tensor
(
tVsV_
.
data
(),
unsqueeze
<
2
>
(
layout
<
0
>
(
tVsV_
.
layout
()))
)
;
Tensor
tVsV
=
make_tensor
(
tVsV_
.
data
(),
reshape_thread_tile
(
tVsV_
.
layout
()));
if
(
block_table
!=
nullptr
)
{
if
(
block_table
!=
nullptr
)
{
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
...
@@ -718,8 +719,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -718,8 +719,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
tKgKnew_
=
gmem_thr_copy_KV_new
.
partition_S
(
gKnew
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tKgKnew_
=
gmem_thr_copy_KV_new
.
partition_S
(
gKnew
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tVgVnew_
=
gmem_thr_copy_KV_new
.
partition_S
(
gVnew
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVgVnew_
=
gmem_thr_copy_KV_new
.
partition_S
(
gVnew
);
// (VCPY, VCPY_N, VCPY_K)
auto
tKgKnew
=
make_tensor
(
tKgKnew_
.
data
(),
unsqueeze
<
2
>
(
layout
<
0
>
(
tKgKnew_
.
layout
()))
)
;
auto
tKgKnew
=
make_tensor
(
tKgKnew_
.
data
(),
reshape_thread_tile
(
tKgKnew_
.
layout
()));
auto
tVgVnew
=
make_tensor
(
tVgVnew_
.
data
(),
unsqueeze
<
2
>
(
layout
<
0
>
(
tVgVnew_
.
layout
()))
)
;
auto
tVgVnew
=
make_tensor
(
tVgVnew_
.
data
(),
reshape_thread_tile
(
tVgVnew_
.
layout
()));
const
int
n_block_copy_min
=
std
::
max
(
n_block_min
,
binfo
.
seqlen_k_cache
/
kBlockN
);
const
int
n_block_copy_min
=
std
::
max
(
n_block_min
,
binfo
.
seqlen_k_cache
/
kBlockN
);
auto
tKgK_data
=
tKgK
.
data
();
auto
tKgK_data
=
tKgK
.
data
();
...
...
csrc/flash_attn/src/utils.h
View file @
bde5aec8
...
@@ -344,10 +344,14 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
...
@@ -344,10 +344,14 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
class
Shape
,
class
Stride
>
// somewhat unorthodox reshape function. Given a tuple ((v1, v2), m, k), returns (v1, v2, k),
__forceinline__
__device__
constexpr
auto
unsqueeze
(
Layout
<
Shape
,
Stride
>
l
)
{
// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
return
make_layout
(
insert
<
N
>
(
l
.
shape
(),
Int
<
1
>
{}),
// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
insert
<
N
>
(
l
.
stride
(),
Int
<
0
>
{}));
template
<
class
Shape
,
class
Stride
>
__forceinline__
__device__
auto
reshape_thread_tile
(
Layout
<
Shape
,
Stride
>
l
)
{
return
make_layout
(
append
(
get
<
0
>
(
l
.
shape
()),
get
<
2
>
(
l
.
shape
())),
append
(
get
<
0
>
(
l
.
stride
()),
get
<
2
>
(
l
.
stride
())));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
tests/test_flash_attn.py
View file @
bde5aec8
...
@@ -1818,22 +1818,22 @@ def test_flash_attn_splitkv(
...
@@ -1818,22 +1818,22 @@ def test_flash_attn_splitkv(
# @pytest.mark.parametrize("num_splits", [1])
# @pytest.mark.parametrize("num_splits", [1])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
])
# @pytest.mark.parametrize("new_kv", [False])
# @pytest.mark.parametrize("new_kv", [False])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
True
])
# @pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("alibi", [False])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("causal", [False])
@
pytest
.
mark
.
parametrize
(
"seqlen_new_eq_seqlen_q"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seqlen_new_eq_seqlen_q"
,
[
True
,
False
])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@
pytest
.
mark
.
parametrize
(
"rotary_interleaved"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"rotary_interleaved"
,
[
False
])
# @pytest.mark.parametrize("rotary_interleaved", [False])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@
pytest
.
mark
.
parametrize
(
"paged_kv_block_size"
,
[
16
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"paged_kv_block_size"
,
[
16
,
48
,
256
,
512
])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
128
,
256
])
...
@@ -1844,17 +1844,8 @@ def test_flash_attn_splitkv(
...
@@ -1844,17 +1844,8 @@ def test_flash_attn_splitkv(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"seqlen_q,seqlen_k"
,
"seqlen_q,seqlen_k"
,
[
[
(
1
,
128
),
(
1
,
10
*
1024
),
(
1
,
339
),
(
16
,
10
*
1024
),
(
3
,
1024
),
(
64
,
800
),
(
64
,
256
),
(
3
,
799
),
(
64
,
2048
),
(
16
,
20000
),
(
1
,
128
*
1024
),
(
16
,
128
*
1024
),
(
128
,
128
),
],
],
)
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment