Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
3722ec71
"...src/ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "337418fd0024862baeb82adebdb9315edde04e4b"
Commit
3722ec71
authored
Feb 28, 2026
by
zhanghj2
Browse files
优化nmz fp8 tp1
parent
34489f46
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
389 additions
and
3 deletions
+389
-3
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+336
-3
csrc/extension/utils.h
csrc/extension/utils.h
+53
-0
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
3722ec71
...
@@ -1075,7 +1075,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1075,7 +1075,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
//64*576
make_stride
(
params
.
k_row_stride
,
_1
{}));
//64*576
const
auto
gK_data
=
gK
.
data
();
Tensor
gV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Tensor
gV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
//64*512
make_stride
(
params
.
k_row_stride
,
_1
{}));
//64*512
...
@@ -1099,6 +1099,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1099,6 +1099,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
//16*32*32
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
//16*32*32
union
Fp8_storage
union
Fp8_storage
{
{
// uint32x4_t val;
intx4_t
data
;
intx4_t
data
;
intx2_t
p
[
2
];
intx2_t
p
[
2
];
int32_t
fp8_array
[
4
];
int32_t
fp8_array
[
4
];
...
@@ -1198,10 +1199,341 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1198,10 +1199,341 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acco_f32
[
i
].
w
=
0.0
f
;
acco_f32
[
i
].
w
=
0.0
f
;
}
}
constexpr
static
int
STAGE
=
8
;
constexpr
static
int
STAGE
=
8
;
extern
__shared__
char
shared_memory
[];
struct
IsMaskBlock
{};
struct
IsMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsNoMaskBlock
{};
struct
IsNoMaskBlock
{};
// Tag type for process_one_block: when the lambda's second argument has this
// type, its compile-time flag IS_LAST_BLOCK evaluates to true.
struct
IsLastBlock
{};
int
lane_id
=
tidx
%
64
;
int
row
=
lane_id
/
4
;
int
col
=
lane_id
%
4
;
col
=
(
col
+
(
row
/
2
)
%
4
)
%
4
;
const
auto
lds_offset
=
row
*
64
+
col
*
16
+
(
warp_id
/
4
)
*
64
*
64
;
uint8_t
*
kv_lds_write_ptr_base
=
reinterpret_cast
<
uint8_t
*>
(
shared_memory
)
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
+
row
*
64
+
col
*
16
;
Fp8_storage
kv_data
[
5
];
{
int
cur_block_table
;
// const int *cur_block_table_ptr;
cur_block_table
=
block_table
[
n_block
];
index_t
offset_k
;
//gK.data() = gK_data + (offset_k);
// cur_block_table_ptr = block_table + n_block;
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK_data
+
(
offset_k
);
// buffer_load_copy_fp8_tp1<false, true, 0>(gK, kv_data[0].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, true, 1>(gK, kv_data[1].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, true, 2>(gK, kv_data[2].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, true, 3>(gK, kv_data[3].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, false, 4>(gK, kv_data[4].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base;
// // for (int i = 0; i < )
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[0].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[1].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[2].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[3].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[4].data;
// kv_lds_write_ptr += 64 * 128;
// gK.data() = gK.data() + (offset_k);
auto
gK_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
k_row_stride
;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const
int
k_zero_pad
=
std
::
max
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
+
1
)
*
16
-
seqlen_k
,
0
);
uint32x4_t
gK_rscr
=
make_rscr
((
unsigned
char
*
)(
gK
.
data
().
get
()
+
gK_offset
),
params
.
k_row_stride
,
k_zero_pad
);
auto
k_lds_addr
=
reinterpret_cast
<
size_t
>
(
sK
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
);
if
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
)
*
16
<
seqlen_k
)
{
k_lds_addr
|=
0x80000000
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
0
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
if
(
warp_id
<
4
)
{
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
512
,
1
,
1
,
0
,
0
);
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
0
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
1
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
2
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
3
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
}
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
static
constexpr
bool
IS_NO_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_LAST_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsLastBlock
>
;
v4f
accs_f32
[
2
];
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
accs_f32
[
i
].
x
=
0.0
f
;
accs_f32
[
i
].
y
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
__syncthreads
();
auto
k_lds_read_ptr
=
sK
.
data
().
get
()
+
(
warp_id
/
4
)
*
16
*
64
;
constexpr
static
int
k_read_lds_offset
=
32
*
64
;
{
constexpr
static
int
k_idx
=
0
;
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
1
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
2
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
3
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
4
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
5
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
6
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
7
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
8
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
block_idx
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
block_idx
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
}
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
Fp8_storage
p_fp8
;
{
__builtin_amdgcn_sched_barrier
(
0
);
int
tid
=
threadIdx
.
x
%
64
;
int
warp_id
=
threadIdx
.
x
/
64
;
int32_t
result
;
result
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
0
,
0
,
0
),
acc_s
(
1
,
0
,
0
),
result
,
false
);
result
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
2
,
0
,
0
),
acc_s
(
3
,
0
,
0
),
result
,
true
);
// int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64]));
// *lds_ptr = result;
int32_t
result1
;
result1
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
0
,
0
,
1
),
acc_s
(
1
,
0
,
1
),
result1
,
false
);
result1
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
2
,
0
,
1
),
acc_s
(
3
,
0
,
1
),
result1
,
true
);
// lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8]));
// *lds_ptr = result1;
int32_t
*
lds_ptr
=
reinterpret_cast
<
int32_t
*>
(
&
(
sP
[
(
tid
%
16
)
*
16
+
((
tid
%
64
)
/
16
)
*
16
*
16
+
(
warp_id
/
4
)
*
4
+
(
warp_id
%
4
)
*
16
*
64
]));
*
lds_ptr
=
result
;
int32_t
*
lds_ptr1
=
reinterpret_cast
<
int32_t
*>
(
&
(
sP
[
(
tid
%
16
)
*
16
+
((
tid
%
64
)
/
16
)
*
16
*
16
+
(
warp_id
/
4
)
*
4
+
(
warp_id
%
4
)
*
16
*
64
+
8
]));
*
lds_ptr1
=
result1
;
__syncthreads
();
p_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
+
(
warp_id
%
4
)
*
16
*
64
]));
__builtin_amdgcn_sched_barrier
(
0
);
}
if
(
block_idx
>
n_block_min
)
{
int
cur_block_table
;
const
int
*
cur_block_table_ptr
;
cur_block_table
=
block_table
[
block_idx
-
1
];
index_t
offset_k
;
// cur_block_table_ptr = block_table + block_idx;
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK_data
+
(
offset_k
);
buffer_load_copy_fp8_tp1
<
true
,
true
,
0
>
(
gK
,
kv_data
[
0
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
true
,
1
>
(
gK
,
kv_data
[
1
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
true
,
2
>
(
gK
,
kv_data
[
2
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
true
,
3
>
(
gK
,
kv_data
[
3
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
false
,
4
>
(
gK
,
kv_data
[
4
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
}
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
Fp8_storage
v0_0
,
v0_1
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
));
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
intx2_t
v
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
];
v
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
0
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
}
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
+
32
*
64
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
+
32
*
64
));
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
intx2_t
v
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
];
v
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
1
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
}
}
if
(
block_idx
>
n_block_min
)
{
__syncthreads
();
uint8_t
*
kv_lds_write_ptr
=
kv_lds_write_ptr_base
;
// for (int i = 0; i < )
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
0
].
data
;
kv_lds_write_ptr
+=
64
*
128
;
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
1
].
data
;
kv_lds_write_ptr
+=
64
*
128
;
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
2
].
data
;
kv_lds_write_ptr
+=
64
*
128
;
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
3
].
data
;
kv_lds_write_ptr
+=
64
*
128
;
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
4
].
data
;
}
// asm volatile("s_barrier \n\t");
};
#if 0
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
...
@@ -1494,7 +1826,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1494,7 +1826,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
}
asm volatile("s_barrier \n\t");
asm volatile("s_barrier \n\t");
};
};
#endif
#if 1
#if 1
if
constexpr
(
n_masking_steps
==
1
)
{
if
constexpr
(
n_masking_steps
==
1
)
{
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
...
...
csrc/extension/utils.h
View file @
3722ec71
...
@@ -2748,6 +2748,59 @@ lds_direct_copy_qkvfp8_zero_lds(
...
@@ -2748,6 +2748,59 @@ lds_direct_copy_qkvfp8_zero_lds(
#endif
#endif
}
}
/// Loads one 16-byte slice per lane from global memory into `dst` using a
/// single `__builtin_amdgcn_buffer_load_dwordx4`.
///
/// Lane/warp to address mapping (all derived from `threadIdx.x`, warp size 64):
///   - row    = lane / 4 + (warp_id % 4) * 16
///   - column = (lane % 4) * 16 + k_idx * 128 + (warp_id / 4) * 64
///   - offset = row * row_stride + column   (element size is 1 byte)
///
/// @tparam Is_even_MN  When false, lanes whose row is >= `max_MN` issue the
///                     load with offset -1 instead of a real address.
/// @tparam Is_even_K   When false, lanes whose column is >= 576 issue the
///                     load with offset -1 instead of a real address.
/// @tparam k_idx       Column-group selector; contributes `k_idx * 128` to the
///                     column offset. Must be given explicitly by the caller.
/// @param src          Source tensor; only `src.data().get()` (the base
///                     address) is used.
/// @param dst          Receives the four dwords returned by the buffer load.
/// @param row_stride   Distance between consecutive rows, in elements
///                     (equal to bytes here because element size is 1).
/// @param max_MN       Number of valid rows; only consulted when
///                     `Is_even_MN` is false.
///
/// NOTE(review): offset -1 presumably relies on the buffer descriptor's range
/// check so that out-of-bounds lanes read zeros — confirm against the ISA docs.
template<bool Is_even_MN=true, bool Is_even_K=true, int k_idx,
         class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void buffer_load_copy_fp8_tp1(
    Tensor<SrcEngine, SrcLayout> const& src,
    intx4_t& dst, const int row_stride, const int max_MN = 0)
{
    constexpr int warp_size = 64;
    constexpr int element_size = 1;          // bytes per element
    constexpr int elements_per_thread = 16;  // one dwordx4 load = 16 bytes
    constexpr int column_limit = 576;        // columns at/after this are masked when !Is_even_K
    constexpr int invalid_offset = -1;       // offset used for masked-out lanes

    const int tidx = threadIdx.x;
    // readfirstlane keeps the warp index uniform across the wavefront.
    const int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    const int lane = tidx % warp_size;

    // Split the 64-bit base address into two 32-bit words with shifts instead
    // of punning through a struct pointer, which violated strict aliasing.
    const uint64_t base_addr = reinterpret_cast<uint64_t>(src.data().get());
    const uint32_t base_lo = static_cast<uint32_t>(base_addr);
    const uint32_t base_hi = static_cast<uint32_t>(base_addr >> 32);
    // Disabled in the original: base_hi |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride

    // Buffer resource descriptor handed to the buffer-load builtin.
    // NOTE(review): words 2 and 3 are presumably the record count and the
    // format/flag word of the descriptor — confirm against the ISA docs.
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(base_lo);
    global_addr[1] = __builtin_amdgcn_readfirstlane(base_hi);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;

    // Four lanes cover one row; each lane owns a 16-element column slice.
    const int row_offset = lane / 4 + (warp_id % 4) * 16;
    const int col_offset = (lane % 4) * elements_per_thread
                         + k_idx * 128
                         + (warp_id / 4) * 64;

    int offset_v = (row_offset * row_stride + col_offset) * element_size;  // bytes

    // The masks depend on template constants, so resolve them at compile time.
    if constexpr (!Is_even_K) {
        if (col_offset >= column_limit) {
            offset_v = invalid_offset;
        }
    }
    if constexpr (!Is_even_MN) {
        if (row_offset >= max_MN) {
            offset_v = invalid_offset;
        }
    }

    dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace flash
}
// namespace flash
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment