Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
72b2aea0
Commit
72b2aea0
authored
Mar 02, 2026
by
zhanghj2
Browse files
优化nmz tp8
parent
6d3ed1da
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
5 deletions
+9
-5
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+7
-3
csrc/extension/utils.h
csrc/extension/utils.h
+2
-2
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
72b2aea0
...
@@ -863,8 +863,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -863,8 +863,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
constexpr
static
int
STAGE
=
8
;
constexpr
static
int
STAGE
=
8
;
#if 1
#if 1
uint8_t
*
kv_lds_write_ptr_base
=
reinterpret_cast
<
uint8_t
*>
(
&
tSsK
(
0
,
0
,
0
));
extern
__shared__
char
shared_memory
[];
int
row_
=
lane_idx
/
8
;
int
col_
=
lane_idx
%
8
;
int
swizzle_col_
=
row_
^
col_
;
uint8_t
*
kv_lds_write_ptr_base
=
reinterpret_cast
<
uint8_t
*>
(
shared_memory
)
+
row_
*
128
+
swizzle_col_
*
16
+
warp_idx
*
16
*
64
;
v4f
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
;
v4f
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
;
c0_0
.
x
=
0.0
f
;
c0_0
.
y
=
0.0
f
;
c0_0
.
z
=
0.0
f
;
c0_0
.
w
=
0.0
f
;
c0_0
.
x
=
0.0
f
;
c0_0
.
y
=
0.0
f
;
c0_0
.
z
=
0.0
f
;
c0_0
.
w
=
0.0
f
;
c0_1
.
x
=
0.0
f
;
c0_1
.
y
=
0.0
f
;
c0_1
.
z
=
0.0
f
;
c0_1
.
w
=
0.0
f
;
c0_1
.
x
=
0.0
f
;
c0_1
.
y
=
0.0
f
;
c0_1
.
z
=
0.0
f
;
c0_1
.
w
=
0.0
f
;
...
@@ -877,7 +882,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -877,7 +882,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c3_0
.
x
=
0.0
f
;
c3_0
.
y
=
0.0
f
;
c3_0
.
z
=
0.0
f
;
c3_0
.
w
=
0.0
f
;
c3_0
.
x
=
0.0
f
;
c3_0
.
y
=
0.0
f
;
c3_0
.
z
=
0.0
f
;
c3_0
.
w
=
0.0
f
;
c3_1
.
x
=
0.0
f
;
c3_1
.
y
=
0.0
f
;
c3_1
.
z
=
0.0
f
;
c3_1
.
w
=
0.0
f
;
c3_1
.
x
=
0.0
f
;
c3_1
.
y
=
0.0
f
;
c3_1
.
z
=
0.0
f
;
c3_1
.
w
=
0.0
f
;
extern
__shared__
char
shared_memory
[];
struct
IsMaskBlock
{};
struct
IsMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsNoMaskBlock
{};
struct
IsNoMaskBlock
{};
...
...
csrc/extension/utils.h
View file @
72b2aea0
...
@@ -642,8 +642,8 @@ buffer_load_copy_qkvfp8(
...
@@ -642,8 +642,8 @@ buffer_load_copy_qkvfp8(
global_addr
[
3
]
=
0x00020000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
int
mma_k
=
32
*
64
;
int
row
=
lane
%
16
;
int
row
=
lane
/
4
;
int
col
=
lane
/
16
;
int
col
=
lane
%
4
;
int
row_offset
=
row
+
(
warp_id
*
16
);
int
row_offset
=
row
+
(
warp_id
*
16
);
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment