Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
c3cf875a
Commit
c3cf875a
authored
Jan 28, 2026
by
zhanghj2
Browse files
实现sparse prefill, 还有bug
parent
50e2de8d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
725 additions
and
84 deletions
+725
-84
csrc/api/sparse_fwd.h
csrc/api/sparse_fwd.h
+8
-30
csrc/sm90/prefill/sparse/config.h
csrc/sm90/prefill/sparse/config.h
+85
-49
csrc/sm90/prefill/sparse/phase1.cuh
csrc/sm90/prefill/sparse/phase1.cuh
+433
-4
csrc/softmax.h
csrc/softmax.h
+1
-1
csrc/utils.h
csrc/utils.h
+198
-0
No files found.
csrc/api/sparse_fwd.h
View file @
c3cf875a
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "params.h"
#include "params.h"
//
#include "sm90/prefill/sparse/phase1.h"
#include "sm90/prefill/sparse/phase1.h"
enum
class
FwdFeatures
:
int
{
enum
class
FwdFeatures
:
int
{
...
@@ -39,7 +39,7 @@ protected:
...
@@ -39,7 +39,7 @@ protected:
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
DISPATCH_HEAD_DIM
(
params
.
d_qk
,
HEAD_DIM_QK
,
[
&
]()
{
DISPATCH_HEAD_DIM
(
params
.
d_qk
,
HEAD_DIM_QK
,
[
&
]()
{
DISPATCH_BOOLEAN_FLAG
(
params
.
topk_length
!=
nullptr
,
HAVE_TOPK_LENGTH
,
[
&
]()
{
DISPATCH_BOOLEAN_FLAG
(
params
.
topk_length
!=
nullptr
,
HAVE_TOPK_LENGTH
,
[
&
]()
{
//
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
sm90
::
fwd
::
run_fwd_phase1_kernel
<
HEAD_DIM_QK
,
HAVE_TOPK_LENGTH
>
(
params
);
});
});
});
});
}
}
...
@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
...
@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features
.
push_back
(
FwdFeatures
::
TOPK_LENGTH
);
required_features
.
push_back
(
FwdFeatures
::
TOPK_LENGTH
);
}
}
// if (is_sm90a) {
if
(
is_sm90a
)
{
// Fwd_Sm90_Impl fwd_impl;
Fwd_Sm90_Impl
fwd_impl
;
// fwd_impl.run(params, required_features);
fwd_impl
.
run
(
params
,
required_features
);
// } else if (is_sm100f) {
}
else
{
// if (h_q == 64) {
TORCH_CHECK
(
false
,
"Unsupported architecture"
);
// Fwd_Sm100_Head64_Impl fwd_impl;
}
// fwd_impl.run(params, required_features);
// } else if (h_q == 128) {
// Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
// Fwd_Sm100_Head128_Impl regular_impl;
// bool use_small_topk_impl = false;
// if (
// (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
// !regular_impl.check_if_all_features_are_supported(required_features)
// ) {
// use_small_topk_impl = true;
// }
// if (use_small_topk_impl) {
// small_topk_impl.run(params, required_features);
// } else {
// regular_impl.run(params, required_features);
// }
// } else {
// TORCH_CHECK(false, "Unsupported h_q: ", h_q);
// }
// } else {
// TORCH_CHECK(false, "Unsupported architecture");
// }
return
{
out
,
max_logits
,
lse
};
return
{
out
,
max_logits
,
lse
};
}
}
csrc/sm90/prefill/sparse/config.h
View file @
c3cf875a
...
@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK;
...
@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK;
static
constexpr
int
D_K
=
D_QK
;
static
constexpr
int
D_K
=
D_QK
;
static
constexpr
int
D_V
=
512
;
static
constexpr
int
D_V
=
512
;
static
constexpr
int
B_H
=
64
;
static
constexpr
int
kNWarps
=
4
;
static
constexpr
int
B_H
=
16
;
static
constexpr
int
B_TOPK
=
64
;
// TopK block size
static
constexpr
int
B_TOPK
=
64
;
// TopK block size
static
constexpr
int
NUM_THREADS
=
128
*
3
;
static
constexpr
int
NUM_THREADS
=
kNWarps
*
64
;
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// We use this number as the initial value for mi (max logits)
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// We use this number as the initial value for mi (max logits)
template
<
int
NUM_TILES
>
using
Element
=
cutlass
::
bfloat16_t
;
using
SmemLayoutQTiles
=
decltype
(
coalesce
(
tile_to_shape
(
using
elem_type
=
Element
;
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
using
ElementAccum
=
float
;
Shape
<
Int
<
B_H
>
,
Int
<
64
*
NUM_TILES
>>
{},
using
index_t
=
int64_t
;
Step
<
_1
,
_2
>
{}
static
constexpr
int
kBlockM
=
B_H
;
),
Shape
<
_1
,
_1
>
{}));
static
constexpr
int
kBlockN
=
B_TOPK
;
static
constexpr
int
kHeadDim
=
D_QK
;
template
<
int
NUM_TILES
>
static
constexpr
int
kHeadDimV
=
D_V
;
using
SmemLayoutOTiles
=
decltype
(
coalesce
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
Shape
<
Int
<
B_H
>
,
Int
<
64
*
NUM_TILES
>>
{},
// 没打开?
Step
<
_1
,
_2
>
{}
// #if defined(__gfx936__) || defined(__gfx938__) || 1
),
Shape
<
_1
,
_1
>
{}));
// using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
template
<
int
NUM_TILES
>
// MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
using
SmemLayoutKTiles
=
decltype
(
coalesce
(
tile_to_shape
(
// MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
GMMA
::
Layout_SW128_Atom
<
bf16
,
GMMA
::
Major
::
K
>
{},
// >;
Shape
<
Int
<
B_TOPK
>
,
Int
<
64
*
NUM_TILES
>>
{},
using
MMA_Atom_Arch
=
std
::
conditional_t
<
Step
<
_1
,
_2
>
{}
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
),
Shape
<
_1
,
_1
>
{}));
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NN
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NN
>
template
<
int
NUM_TILES
>
>
;
using
SmemLayoutKTilesTransposed
=
decltype
(
composition
(
using
TiledMma
=
TiledMMA
<
SmemLayoutKTiles
<
NUM_TILES
>
{},
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
64
*
NUM_TILES
>
,
Int
<
B_TOPK
>>
,
Stride
<
Int
<
B_TOPK
>
,
_1
>>
{}
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
));
ValLayoutMNK
>
;
// #endif
using
SmemLayoutQ
=
SmemLayoutQTiles
<
D_Q
/
64
>
;
using
SmemLayoutO
=
SmemLayoutOTiles
<
D_V
/
64
>
;
using
MMA_Atom_Arch_16x32
=
std
::
conditional_t
<
using
SmemLayoutK
=
SmemLayoutKTiles
<
D_Q
/
64
>
;
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
using
SmemLayoutV
=
SmemLayoutKTilesTransposed
<
D_V
/
64
>
;
MMA_Atom
<
GFX928_16x32x16_F32F16F16F32_NT
>
,
using
SmemLayoutHalfV
=
SmemLayoutKTilesTransposed
<
D_V
/
64
/
2
>
;
MMA_Atom
<
GFX928_16x32x16_F32BF16BF16F32_NT
>
>
;
using
SmemLayoutS
=
decltype
(
coalesce
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
using
TiledMma_O
=
TiledMMA
<
Shape
<
Int
<
B_H
>
,
Int
<
B_TOPK
>>
{}
MMA_Atom_Arch_16x32
,
),
Shape
<
_1
,
_1
>
{}));
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
SmemLayoutAtomQ
=
Layout
<
Shape
<
Int
<
16
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
;
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{}));
using
SmemLayoutAtomK
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
{}));
using
SmemLayoutK
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
16
*
32
>>
{}));
using
SmemLayoutAtomV
=
SmemLayoutAtomK
;
using
SmemLayoutV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
composition
(
SmemLayoutV
{},
make_layout
(
Shape
<
Int
<
kHeadDimV
>
,
Int
<
kBlockN
>>
{},
GenRowMajor
{})));
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutVtransposed
{}));
using
SmemLayoutAtomP
=
Layout
<
Shape
<
Int
<
4
*
16
*
16
>>
,
Stride
<
Int
<
1
>>>
;
using
SmemLayoutP
=
decltype
(
tile_to_shape
(
SmemLayoutAtomP
{},
Shape
<
Int
<
4
*
16
*
16
>>
{}));
using
SmemLayoutRow
=
Layout
<
Shape
<
_128
>
,
Stride
<
_1
>>
;
using
SmemLayoutK_place_holder
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
4
*
32
>>
{}));
struct
SharedMemoryPlan
{
struct
SharedMemoryPlan
{
union
{
union
{
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutQ
>>
q
;
struct
{
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutO
>>
o
;
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutV
>>
smem_v
;
// Double buffer
}
q_o
;
};
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutK
>>
k
[
2
];
struct
{
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutS
>>
s
[
D_QK
==
576
?
1
:
2
];
// For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutK_place_holder
>>
smem_place_holder
;
// Double buffer
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutP
>>
smem_p
;
bool
is_kv_valid
[
2
][
B_TOPK
];
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_sum
;
float2
sM
[
32
];
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
float2
sL
[
64
];
// For reduction across WG0/1 in epilogue
};
float
final_max_logits
[
64
],
final_lse
[
64
];
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutQ
>>
smem_q
;
};
};
// transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
// transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
};
};
...
...
csrc/sm90/prefill/sparse/phase1.cuh
View file @
c3cf875a
#pragma once
#pragma once
#include "config.h"
#include "config.h"
#include "utils.h"
#include "utils.h"
#include "softmax.h"
#include "../../helpers.h"
#include "../../helpers.h"
namespace
sm90
::
fwd
{
namespace
sm90
::
fwd
{
...
@@ -11,6 +11,433 @@ using namespace cute;
...
@@ -11,6 +11,433 @@ using namespace cute;
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
__device__
void
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
__device__
void
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
extern
__shared__
char
smem_
[];
SharedMemoryPlan
&
plan
=
*
reinterpret_cast
<
SharedMemoryPlan
*>
(
smem_
);
const
int
tidx
=
threadIdx
.
x
;
static
constexpr
int
kBlockM
=
B_H
;
static
constexpr
int
kBlockN
=
B_TOPK
;
static
constexpr
int
kHeadDim
=
D_QK
;
static
constexpr
int
kHeadDimV
=
D_V
;
const
int
warp_idx
=
tidx
/
64
;
const
int
s_q_idx
=
blockIdx
.
x
;
const
int
bidh
=
blockIdx
.
y
;
const
int
lane_idx
=
tidx
%
64
;
const
index_t
row_offset_q
=
s_q_idx
*
params
.
stride_q_s_q
+
bidh
*
kBlockM
*
params
.
stride_q_h_q
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
stride_q_h_q
,
_1
{}));
const
index_t
row_offset_k
=
0
*
params
.
stride_kv_h_kv
;
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
kv
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
stride_kv_s_kv
,
_1
{}));
const
index_t
row_offset_topk
=
s_q_idx
*
params
.
stride_indices_s_q
;
int
*
gIndices
=
reinterpret_cast
<
int
*>
(
params
.
indices
)
+
row_offset_topk
;
Tensor
sQ
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_q
.
data
()),
SmemLayoutQ
{});
Tensor
sV
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
SmemLayoutV
{});
Tensor
sK
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
SmemLayoutK
{});
Tensor
sP
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_p
.
data
()),
SmemLayoutP
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
SmemLayoutVtransposedNoSwizzle
{});
Tensor
sRow_max_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_max
.
data
()),
SmemLayoutRow
{});
Tensor
sRow_sum_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_sum
.
data
()),
SmemLayoutRow
{});
TiledMMA
tiled_mma
=
TiledMma
{};
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
TiledMMA
tiled_mma_o
=
TiledMma_O
{};
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
0
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
1
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
2
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
3
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
false
,
true
>
(
gQ
,
sQ
,
4
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
Tensor
tSrQ_copy_view
=
smem_thr_copy_Q
.
retile_D
(
tSrQ
);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
15
),
tSrQ_copy_view
(
_
,
_
,
15
));
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
16
),
tSrQ_copy_view
(
_
,
_
,
16
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
17
),
tSrQ_copy_view
(
_
,
_
,
17
));
__syncthreads
();
const
int
topk_length
=
HAVE_TOPK_LENGTH
?
__ldg
(
params
.
topk_length
+
s_q_idx
)
:
params
.
topk
;
const
int
num_topk_blocks
=
HAVE_TOPK_LENGTH
?
ku
::
ceil_div
(
topk_length
,
(
int
)
B_TOPK
)
:
(
int
)((
unsigned
int
)
params
.
topk
/
(
unsigned
int
)
B_TOPK
);
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
Tensor
tSrK_smem
=
thr_mma
.
partition_fragment_B
(
gK
);
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
Copy_Atom
<
GFX928_DS_READ_DS_M32x16_B16
,
Element
>
{},
tiled_mma_o
);
auto
smem_thr_copy_V
=
smem_tiled_copy_V
.
get_thread_slice
(
tidx
);
Tensor
tOsVt
=
smem_thr_copy_V
.
partition_S
(
sVt
);
Tensor
tOrVt
=
thr_mma_o
.
partition_fragment_B
(
sVtNoSwizzle
);
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
clear
(
acc_o
);
flash
::
Softmax
<
size
<
1
>
(
acc_o
)
>
softmax
;
auto
calc_row_and_col
=
[
&
](
const
int
block_idx
)
->
std
::
tuple
<
int
,
int
>
{
// 计算swizzle后的全局显存访存地址
int
virtual_row
=
lane_idx
/
8
;
int
virtual_col
=
lane_idx
%
8
;
int
swizzle_col
=
virtual_row
^
virtual_col
;
int
row
=
lane_idx
/
4
;
row
=
(
row
>=
8
)
^
row
;
int
col
=
swizzle_col
%
4
;
int
warp_id
=
tidx
/
64
;
int
row_offset
=
block_idx
*
kBlockN
+
row
+
(
warp_idx
*
16
)
;
// row_offset = row_offset < params.topk ? gIndices[row_offset] : -1;
row_offset
=
gIndices
[
row_offset
];
return
{
row_offset
,
col
};
};
for
(
int
block_idx
=
0
;
block_idx
<
num_topk_blocks
;
block_idx
++
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
clear
(
acc_s
);
auto
[
row_offset
,
col
]
=
calc_row_and_col
(
block_idx
);
if
constexpr
(
D_QK
==
576
)
{
for
(
int
i
=
16
;
i
<
18
;
i
++
)
{
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
i
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
+
16
),
tSrK
(
_
,
_
,
0
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
0
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
+
16
),
tSrK
(
_
,
_
,
1
),
acc_s
);
}
else
{
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
0
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
i
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
}
int
k_idx
=
0
;
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
4
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
5
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
6
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
7
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
8
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
9
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
10
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
11
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
12
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
13
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
14
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
15
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_barrier
\n\t
"
);
// if (block0())
// {
// printf(" %.2f %.2f %.2f \n ", acc_s(0), acc_s(1), acc_s(2));
// }
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
auto
is_valid_token
=
[
&
](
const
int
idx
)
->
bool
{
int
offs
=
int
(
get
<
1
>
(
tScS
(
idx
)))
+
block_idx
*
kBlockN
;
int
t
=
gIndices
[
offs
];
bool
is_cur_token_valid
=
t
>=
0
&&
t
<
params
.
s_kv
;
if
constexpr
(
HAVE_TOPK_LENGTH
)
{
is_cur_token_valid
=
is_cur_token_valid
&&
(
offs
<
topk_length
);
}
return
is_cur_token_valid
;
};
{
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
// idx = idx < params.topk ? gIndices[idx] : -1;
if
(
!
is_valid_token
(
i
))
acc_s
(
i
)
=
-
INFINITY
;
}
}
block_idx
==
0
?
softmax
.
template
softmax_rescale_o_prefill
<
/*Is_first=*/
true
,
/*Check_inf=*//*Is_local=*/
false
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
sm_scale_div_log2
)
:
softmax
.
template
softmax_rescale_o_prefill
<
/*Is_first=*/
false
,
/*Check_inf=*//*Is_local=*/
false
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
sm_scale_div_log2
);
// if (block0())
// {
// printf(" %.2f %.2f %.2f %.2f %.2f %.2f \n ", acc_s(0), acc_s(1), acc_s(2), acc_s(3), softmax.row_max(0), params.sm_scale_div_log2);
// }
Tensor
rP
=
flash
::
convert_type
<
Element
>
(
acc_s
);
Tensor
tOrP
=
flash
::
convert_layout_acc_Aregs
(
tiled_mma
,
tiled_mma_o
,
rP
,
sP
);
{
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
// __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
0
),
tOrVt
(
_
,
_
,
0
),
acc_o
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
1
),
tOrVt
(
_
,
_
,
1
),
acc_o
);
// __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
2
),
tOrVt
(
_
,
_
,
2
),
acc_o
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
3
),
tOrVt
(
_
,
_
,
3
),
acc_o
);
// for (int i = 0; i < size(tOrP); i++)
// {
// tOrP(i) = Element(1.0f);
// }
// cute::copy(smem_tiled_copy_V, tOsVt(_, 0, 0), tOrVt_copy_view(_, 0, 0));
// for (int i = 0; i < 4; i++) {
// cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
// // if (tOrVt(_, _, i) )
// cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
// }
// for (int i = 0; i < 8 * 2 * 16; i++)
// {
// }
// asm volatile("s_barrier");
// if (thread0()) {
// for (int i = 0; i < 64; i++) {
// for (int j = 0; j < 512; j++) {
// printf(" %.2f ", float(sK(i, j)));
// }
// printf("\n");
// }
// }
// if (block0())
// {
// print("tidx %d acc_s %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n",
// tidx, acc_o(0), acc_o(1), acc_o(2), acc_o(3),
// acc_o(4), acc_o(5), acc_o(6), acc_o(7),
// acc_o(8), acc_o(9), acc_o(10), acc_o(11),
// acc_o(12), acc_o(13), acc_o(14), acc_o(15)
// );
// }
}
// asm volatile("s_barrier\n\t");
}
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_prefill
<
false
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
sm_scale
);
const
index_t
row_offset_o
=
s_q_idx
*
params
.
h_q
*
params
.
d_v
+
bidh
*
kBlockM
*
params
.
d_v
;
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
out
)
+
row_offset_o
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
params
.
d_v
,
_1
{}));
// lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
const
index_t
row_offset_lse
=
s_q_idx
*
params
.
h_q
+
bidh
*
kBlockM
;
float
*
gLSE
=
reinterpret_cast
<
float
*>
(
params
.
lse
)
+
row_offset_lse
;
// const index_t row_offset_lse = m_block * params.h_q;
float
*
gMax_logits
=
reinterpret_cast
<
float
*>
(
params
.
max_logits
)
+
row_offset_lse
;
if
(
params
.
attn_sink
!=
nullptr
)
{
float
rAttn_sink
=
__ldg
((
float
*
)
params
.
attn_sink
+
start_head_idx
+
lane_idx
%
16
);
if
(
flash
::
is_positive_infinity
(
rAttn_sink
))
{
for
(
int
i
=
0
;
i
<
size
(
acc_o
);
i
++
)
{
acc_o
(
i
)
=
0.0
f
;
}
}
else
{
if
(
!
flash
::
is_positive_infinity
(
lse
(
0
)))
{
float
lse_exp2
=
__builtin_amdgcn_exp2f
(
lse
[
0
]
*
CUDART_L2E_F
);
float
rAttn_sink_exp2
=
__builtin_amdgcn_exp2f
(
rAttn_sink
*
CUDART_L2E_F
);
float
o_scale
=
lse_exp2
/
(
lse_exp2
+
rAttn_sink_exp2
);
for
(
int
i
=
0
;
i
<
size
(
acc_o
);
i
++
)
{
acc_o
(
i
)
*=
o_scale
;
}
}
}
}
// if (block0())
// {
// print("tidx %d acc_s %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n",
// tidx, acc_o(0), acc_o(1), acc_o(2), acc_o(3),
// acc_o(4), acc_o(5), acc_o(6), acc_o(7),
// acc_o(8), acc_o(9), acc_o(10), acc_o(11),
// acc_o(12), acc_o(13), acc_o(14), acc_o(15)
// );
// }
{
// store O and gLSE
auto
rO
=
flash
::
convert_type
<
Element
>
(
acc_o
);
int
row
,
col
;
const
int
warpId
=
tidx
/
64
;
const
int
laneId
=
tidx
%
64
;
for
(
int
mi
=
0
;
mi
<
size
<
1
>
(
acc_o
);
++
mi
)
{
row
=
mi
*
kBlockM
+
laneId
%
16
;
if
(
row
<
params
.
h_q
)
{
for
(
int
ni
=
0
;
ni
<
size
<
2
>
(
acc_o
);
++
ni
)
{
col
=
(
laneId
/
16
)
+
ni
*
128
+
warpId
*
32
;
for
(
int
ei
=
0
;
ei
<
size
<
0
>
(
acc_o
);
++
ei
)
{
gO
(
row
,
col
)
=
rO
(
ei
,
mi
,
ni
);
col
+=
4
;
}
}
gLSE
[
row
]
=
lse
(
mi
);
gMax_logits
[
row
]
=
softmax
.
row_max
(
mi
)
*
params
.
sm_scale
;
}
}
}
}
}
...
@@ -26,9 +453,11 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams ¶
...
@@ -26,9 +453,11 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams ¶
KU_ASSERT
(
params
.
topk
%
(
2
*
B_TOPK
)
==
0
);
// To save some boundry checkings
KU_ASSERT
(
params
.
topk
%
(
2
*
B_TOPK
)
==
0
);
// To save some boundry checkings
KU_ASSERT
(
params
.
topk
>
0
);
KU_ASSERT
(
params
.
topk
>
0
);
KU_ASSERT
(
params
.
h_q
%
B_H
==
0
);
KU_ASSERT
(
params
.
h_q
%
B_H
==
0
);
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>>
;
auto
shape_Q
=
make_shape
(
params
.
h_q
,
params
.
d_qk
,
params
.
s_q
);
constexpr
size_t
smem_size
=
65536
;
dim3
grid
((
params
.
s_q
,
params
.
h_q
/
B_H
),
1
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
KU_CHECK_KERNEL_LAUNCH
();
}
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
...
...
csrc/softmax.h
View file @
c3cf875a
...
@@ -376,7 +376,7 @@ struct Softmax {
...
@@ -376,7 +376,7 @@ struct Softmax {
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
acc_o_rowcol
);
++
mi
)
{
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
acc_o_rowcol
);
++
mi
)
{
float
sum
=
row_sum
(
mi
);
float
sum
=
row_sum
(
mi
);
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__log
2
f
(
sum
);
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__logf
(
sum
);
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
#pragma unroll
#pragma unroll
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
++
ni
)
{
acc_o_rowcol
(
mi
,
ni
)
*=
scale
;
}
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
++
ni
)
{
acc_o_rowcol
(
mi
,
ni
)
*=
scale
;
}
...
...
csrc/utils.h
View file @
c3cf875a
...
@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val)
...
@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val)
return
fp32
.
as_bits
==
inf_tmp
.
as_bits
;
return
fp32
.
as_bits
==
inf_tmp
.
as_bits
;
}
}
template
<
bool
Is_even_MN
=
true
,
bool
Is_even_K
=
true
,
bool
Is_load_Q
=
false
,
class
SrcEngine
,
class
SrcLayout
,
class
DstEngine
,
class
DstLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
DstEngine
,
DstLayout
>
&
dst
,
int
k_idx_
,
const
int
row_stride
,
const
int
max_MN
=
0
)
{
#if defined(__gfx936__) || defined(__gfx938__)
{
if
constexpr
(
Is_load_Q
)
{
// // 32x64
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
8
*
element_size
;
int
mma_k
=
16
*
128
;
int
row
=
lane
%
16
;
int
col
=
lane
/
16
;
int
row_offset
=
row
;
int
col_offset
=
(
col
+
warp_id
*
4
)
*
elements_per_thread
+
k_idx
*
128
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
else
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
8
*
element_size
;
int
mma_k
=
32
*
64
;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// 此处待优化,后8行,行号需要交换
int
virtual_row
=
lane
/
8
;
int
virtual_col
=
lane
%
8
;
int
swizzle_col
=
virtual_row
^
virtual_col
;
int
row
=
lane
/
4
;
// 8->9 9->8
row
=
(
row
>=
8
)
^
row
;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int
col
=
swizzle_col
%
4
;
int
row_offset
=
row
+
(
warp_id
*
16
)
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
32
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
}
#endif
}
template
<
bool
Is_even_K
=
true
,
bool
Is_even_MN
=
true
,
bool
Use_cache_swizzle
=
true
,
class
SrcEngine
,
class
SrcLayout
,
class
DstEngine
,
class
DstLayout
// class IdxEngine, class IdxLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_for_prefill_sparse_mla
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
DstEngine
,
DstLayout
>
&
dst
,
int
row_offset
,
int
col
,
int
k_idx_
,
const
int
row_stride
,
int
max_MN
=
0
)
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
glob_ptr
.
latter
|=
((
row_stride
*
2
)
<<
16
);
// 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
max_MN
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
8
*
element_size
;
int
mma_k
=
32
*
64
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
32
;
int
offset_v
=
(
col_offset
)
*
element_size
;
// bytes
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_MN && (row_offset >= max_MN || row_offset < 0)) offset_v = -1;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
(
k_idx
%
4
)
*
mma_k
*
element_size
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
==
-
1
?
max_MN
:
row_offset
;
index_offset
[
1
]
=
offset_v
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment