Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
c3cf875a
Commit
c3cf875a
authored
Jan 28, 2026
by
zhanghj2
Browse files
实现sparse prefill, 还有bug
parent
50e2de8d
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
725 additions
and
84 deletions
+725
-84
csrc/api/sparse_fwd.h
csrc/api/sparse_fwd.h
+8
-30
csrc/sm90/prefill/sparse/config.h
csrc/sm90/prefill/sparse/config.h
+85
-49
csrc/sm90/prefill/sparse/phase1.cuh
csrc/sm90/prefill/sparse/phase1.cuh
+433
-4
csrc/softmax.h
csrc/softmax.h
+1
-1
csrc/utils.h
csrc/utils.h
+198
-0
No files found.
csrc/api/sparse_fwd.h
View file @
c3cf875a
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "params.h"
#include "params.h"
//
#include "sm90/prefill/sparse/phase1.h"
#include "sm90/prefill/sparse/phase1.h"
enum
class
FwdFeatures
:
int
{
enum
class
FwdFeatures
:
int
{
...
@@ -39,7 +39,7 @@ protected:
...
@@ -39,7 +39,7 @@ protected:
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
DISPATCH_HEAD_DIM
(
params
.
d_qk
,
HEAD_DIM_QK
,
[
&
]()
{
DISPATCH_HEAD_DIM
(
params
.
d_qk
,
HEAD_DIM_QK
,
[
&
]()
{
DISPATCH_BOOLEAN_FLAG
(
params
.
topk_length
!=
nullptr
,
HAVE_TOPK_LENGTH
,
[
&
]()
{
DISPATCH_BOOLEAN_FLAG
(
params
.
topk_length
!=
nullptr
,
HAVE_TOPK_LENGTH
,
[
&
]()
{
//
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
sm90
::
fwd
::
run_fwd_phase1_kernel
<
HEAD_DIM_QK
,
HAVE_TOPK_LENGTH
>
(
params
);
});
});
});
});
}
}
...
@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
...
@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features
.
push_back
(
FwdFeatures
::
TOPK_LENGTH
);
required_features
.
push_back
(
FwdFeatures
::
TOPK_LENGTH
);
}
}
// if (is_sm90a) {
if
(
is_sm90a
)
{
// Fwd_Sm90_Impl fwd_impl;
Fwd_Sm90_Impl
fwd_impl
;
// fwd_impl.run(params, required_features);
fwd_impl
.
run
(
params
,
required_features
);
// } else if (is_sm100f) {
}
else
{
// if (h_q == 64) {
TORCH_CHECK
(
false
,
"Unsupported architecture"
);
// Fwd_Sm100_Head64_Impl fwd_impl;
}
// fwd_impl.run(params, required_features);
// } else if (h_q == 128) {
// Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
// Fwd_Sm100_Head128_Impl regular_impl;
// bool use_small_topk_impl = false;
// if (
// (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
// !regular_impl.check_if_all_features_are_supported(required_features)
// ) {
// use_small_topk_impl = true;
// }
// if (use_small_topk_impl) {
// small_topk_impl.run(params, required_features);
// } else {
// regular_impl.run(params, required_features);
// }
// } else {
// TORCH_CHECK(false, "Unsupported h_q: ", h_q);
// }
// } else {
// TORCH_CHECK(false, "Unsupported architecture");
// }
return
{
out
,
max_logits
,
lse
};
return
{
out
,
max_logits
,
lse
};
}
}
csrc/sm90/prefill/sparse/config.h
View file @
c3cf875a
...
@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK;
...
@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK;
static
constexpr
int
D_K
=
D_QK
;
static
constexpr
int
D_K
=
D_QK
;
static
constexpr
int
D_V
=
512
;
static
constexpr
int
D_V
=
512
;
static
constexpr
int
B_H
=
64
;
static
constexpr
int
kNWarps
=
4
;
static
constexpr
int
B_H
=
16
;
static
constexpr
int
B_TOPK
=
64
;
// TopK block size
static
constexpr
int
B_TOPK
=
64
;
// TopK block size
static
constexpr
int
NUM_THREADS
=
128
*
3
;
static
constexpr
int
NUM_THREADS
=
kNWarps
*
64
;
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// We use this number as the initial value for mi (max logits)
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// We use this number as the initial value for mi (max logits)
template
<
int
NUM_TILES
>
using
Element
=
cutlass
::
bfloat16_t
;
using
SmemLayoutQTiles
=
decltype
(
coalesce
(
tile_to_shape
(
using
elem_type
=
Element
;
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
using
ElementAccum
=
float
;
Shape
<
Int
<
B_H
>
,
Int
<
64
*
NUM_TILES
>>
{},
using
index_t
=
int64_t
;
Step
<
_1
,
_2
>
{}
static
constexpr
int
kBlockM
=
B_H
;
),
Shape
<
_1
,
_1
>
{}));
static
constexpr
int
kBlockN
=
B_TOPK
;
static
constexpr
int
kHeadDim
=
D_QK
;
template
<
int
NUM_TILES
>
static
constexpr
int
kHeadDimV
=
D_V
;
using
SmemLayoutOTiles
=
decltype
(
coalesce
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
Shape
<
Int
<
B_H
>
,
Int
<
64
*
NUM_TILES
>>
{},
// 没打开?
Step
<
_1
,
_2
>
{}
// #if defined(__gfx936__) || defined(__gfx938__) || 1
),
Shape
<
_1
,
_1
>
{}));
// using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
template
<
int
NUM_TILES
>
// MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
using
SmemLayoutKTiles
=
decltype
(
coalesce
(
tile_to_shape
(
// MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
GMMA
::
Layout_SW128_Atom
<
bf16
,
GMMA
::
Major
::
K
>
{},
// >;
Shape
<
Int
<
B_TOPK
>
,
Int
<
64
*
NUM_TILES
>>
{},
using
MMA_Atom_Arch
=
std
::
conditional_t
<
Step
<
_1
,
_2
>
{}
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
),
Shape
<
_1
,
_1
>
{}));
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NN
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NN
>
template
<
int
NUM_TILES
>
>
;
using
SmemLayoutKTilesTransposed
=
decltype
(
composition
(
using
TiledMma
=
TiledMMA
<
SmemLayoutKTiles
<
NUM_TILES
>
{},
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
64
*
NUM_TILES
>
,
Int
<
B_TOPK
>>
,
Stride
<
Int
<
B_TOPK
>
,
_1
>>
{}
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
));
ValLayoutMNK
>
;
// #endif
using
SmemLayoutQ
=
SmemLayoutQTiles
<
D_Q
/
64
>
;
using
SmemLayoutO
=
SmemLayoutOTiles
<
D_V
/
64
>
;
using
MMA_Atom_Arch_16x32
=
std
::
conditional_t
<
using
SmemLayoutK
=
SmemLayoutKTiles
<
D_Q
/
64
>
;
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
using
SmemLayoutV
=
SmemLayoutKTilesTransposed
<
D_V
/
64
>
;
MMA_Atom
<
GFX928_16x32x16_F32F16F16F32_NT
>
,
using
SmemLayoutHalfV
=
SmemLayoutKTilesTransposed
<
D_V
/
64
/
2
>
;
MMA_Atom
<
GFX928_16x32x16_F32BF16BF16F32_NT
>
>
;
using
SmemLayoutS
=
decltype
(
coalesce
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
bf16
>
{},
using
TiledMma_O
=
TiledMMA
<
Shape
<
Int
<
B_H
>
,
Int
<
B_TOPK
>>
{}
MMA_Atom_Arch_16x32
,
),
Shape
<
_1
,
_1
>
{}));
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
SmemLayoutAtomQ
=
Layout
<
Shape
<
Int
<
16
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
;
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{}));
using
SmemLayoutAtomK
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
{}));
using
SmemLayoutK
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
16
*
32
>>
{}));
using
SmemLayoutAtomV
=
SmemLayoutAtomK
;
using
SmemLayoutV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
composition
(
SmemLayoutV
{},
make_layout
(
Shape
<
Int
<
kHeadDimV
>
,
Int
<
kBlockN
>>
{},
GenRowMajor
{})));
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutVtransposed
{}));
using
SmemLayoutAtomP
=
Layout
<
Shape
<
Int
<
4
*
16
*
16
>>
,
Stride
<
Int
<
1
>>>
;
using
SmemLayoutP
=
decltype
(
tile_to_shape
(
SmemLayoutAtomP
{},
Shape
<
Int
<
4
*
16
*
16
>>
{}));
using
SmemLayoutRow
=
Layout
<
Shape
<
_128
>
,
Stride
<
_1
>>
;
using
SmemLayoutK_place_holder
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
4
*
32
>>
{}));
struct
SharedMemoryPlan
{
struct
SharedMemoryPlan
{
union
{
union
{
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutQ
>>
q
;
struct
{
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutO
>>
o
;
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutV
>>
smem_v
;
// Double buffer
}
q_o
;
};
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutK
>>
k
[
2
];
struct
{
array_aligned
<
bf16
,
cosize_v
<
SmemLayoutS
>>
s
[
D_QK
==
576
?
1
:
2
];
// For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutK_place_holder
>>
smem_place_holder
;
// Double buffer
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutP
>>
smem_p
;
bool
is_kv_valid
[
2
][
B_TOPK
];
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_sum
;
float2
sM
[
32
];
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
float2
sL
[
64
];
// For reduction across WG0/1 in epilogue
};
float
final_max_logits
[
64
],
final_lse
[
64
];
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutQ
>>
smem_q
;
};
};
// transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
// transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
};
};
...
...
csrc/sm90/prefill/sparse/phase1.cuh
View file @
c3cf875a
This diff is collapsed.
Click to expand it.
csrc/softmax.h
View file @
c3cf875a
...
@@ -376,7 +376,7 @@ struct Softmax {
...
@@ -376,7 +376,7 @@ struct Softmax {
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
acc_o_rowcol
);
++
mi
)
{
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
acc_o_rowcol
);
++
mi
)
{
float
sum
=
row_sum
(
mi
);
float
sum
=
row_sum
(
mi
);
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__log
2
f
(
sum
);
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__logf
(
sum
);
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
#pragma unroll
#pragma unroll
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
++
ni
)
{
acc_o_rowcol
(
mi
,
ni
)
*=
scale
;
}
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
++
ni
)
{
acc_o_rowcol
(
mi
,
ni
)
*=
scale
;
}
...
...
csrc/utils.h
View file @
c3cf875a
...
@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val)
...
@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val)
return
fp32
.
as_bits
==
inf_tmp
.
as_bits
;
return
fp32
.
as_bits
==
inf_tmp
.
as_bits
;
}
}
template
<
bool
Is_even_MN
=
true
,
bool
Is_even_K
=
true
,
bool
Is_load_Q
=
false
,
class
SrcEngine
,
class
SrcLayout
,
class
DstEngine
,
class
DstLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
DstEngine
,
DstLayout
>
&
dst
,
int
k_idx_
,
const
int
row_stride
,
const
int
max_MN
=
0
)
{
#if defined(__gfx936__) || defined(__gfx938__)
{
if
constexpr
(
Is_load_Q
)
{
// // 32x64
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
8
*
element_size
;
int
mma_k
=
16
*
128
;
int
row
=
lane
%
16
;
int
col
=
lane
/
16
;
int
row_offset
=
row
;
int
col_offset
=
(
col
+
warp_id
*
4
)
*
elements_per_thread
+
k_idx
*
128
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
else
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
8
*
element_size
;
int
mma_k
=
32
*
64
;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// 此处待优化,后8行,行号需要交换
int
virtual_row
=
lane
/
8
;
int
virtual_col
=
lane
%
8
;
int
swizzle_col
=
virtual_row
^
virtual_col
;
int
row
=
lane
/
4
;
// 8->9 9->8
row
=
(
row
>=
8
)
^
row
;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int
col
=
swizzle_col
%
4
;
int
row_offset
=
row
+
(
warp_id
*
16
)
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
32
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
}
#endif
}
template
<
bool
Is_even_K
=
true
,
bool
Is_even_MN
=
true
,
bool
Use_cache_swizzle
=
true
,
class
SrcEngine
,
class
SrcLayout
,
class
DstEngine
,
class
DstLayout
// class IdxEngine, class IdxLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_for_prefill_sparse_mla
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
DstEngine
,
DstLayout
>
&
dst
,
int
row_offset
,
int
col
,
int
k_idx_
,
const
int
row_stride
,
int
max_MN
=
0
)
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
glob_ptr
.
latter
|=
((
row_stride
*
2
)
<<
16
);
// 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
max_MN
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
8
*
element_size
;
int
mma_k
=
32
*
64
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
32
;
int
offset_v
=
(
col_offset
)
*
element_size
;
// bytes
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_MN && (row_offset >= max_MN || row_offset < 0)) offset_v = -1;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
(
k_idx
%
4
)
*
mma_k
*
element_size
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
==
-
1
?
max_MN
:
row_offset
;
index_offset
[
1
]
=
offset_v
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment