Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
d6379e50
Commit
d6379e50
authored
Jan 30, 2026
by
zhanghj2
Browse files
实现了scale使用buffer load读取
parent
bdf0140b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
0 deletions
+33
-0
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
+33
-0
No files found.
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
View file @
d6379e50
...
@@ -284,6 +284,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -284,6 +284,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
typedef
unsigned
char
__hip_fp8_storage_t
;
typedef
unsigned
char
__hip_fp8_storage_t
;
typedef
__fp16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
__fp16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
__fp16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
__fp16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
int
v2i
__attribute__
((
ext_vector_type
(
2
)));
union
Fp8_storage
{
union
Fp8_storage
{
__fp16x8_t
data_128
;
__fp16x8_t
data_128
;
...
@@ -390,6 +391,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -390,6 +391,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
else
{
}
else
{
gK_base
=
k_ptr
+
offset_k
+
rel_idx_in_block
*
(
HEAD_DIM_NOPE
+
HEAD_DIM_ROPE
*
2
);;
gK_base
=
k_ptr
+
offset_k
+
rel_idx_in_block
*
(
HEAD_DIM_NOPE
+
HEAD_DIM_ROPE
*
2
);;
static_assert
(
NUM_SCALES
==
8
);
static_assert
(
NUM_SCALES
==
8
);
#if 1
uint8_t
*
scale_ptr
=
k_ptr
+
offset_k
+
page_block_size
*
(
HEAD_DIM_NOPE
+
HEAD_DIM_ROPE
*
2
)
+
rel_idx_in_block
*
NUM_SCALES
;
uint8_t
*
scale_ptr
=
k_ptr
+
offset_k
+
page_block_size
*
(
HEAD_DIM_NOPE
+
HEAD_DIM_ROPE
*
2
)
+
rel_idx_in_block
*
NUM_SCALES
;
if
(
token_index
==
-
1
)
if
(
token_index
==
-
1
)
{
{
...
@@ -419,6 +421,37 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -419,6 +421,37 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
}
}
}
#else
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
k_ptr
+
offset_k
+
page_block_size
*
(
HEAD_DIM_NOPE
+
HEAD_DIM_ROPE
*
2
));
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
int
offset_v
=
token_index
==
-
1
?
-
1
:
rel_idx_in_block
*
NUM_SCALES
;
union
Scale_e8m0
{
v2i
tmp
;
__hip_fp8_storage_t
fp8_e8m0
[
NUM_SCALES
];
};
Scale_e8m0
scale_e8m0
;
scale_e8m0
.
tmp
=
__builtin_amdgcn_buffer_load_dwordx2
(
global_addr
,
0
,
offset_v
,
0
,
0
);
union
Fp32
{
uint32_t
as_bits
;
float
as_value
;
};
Fp32
fp32
;
for
(
int
i
=
0
;
i
<
NUM_SCALES
-
1
;
i
++
)
{
fp32
.
as_bits
=
(
scale_e8m0
.
fp8_e8m0
[
i
]
<<
23
);
scales
[
i
]
=
fp32
.
as_value
;
}
#endif
// if (block0() && threadIdx.x < 64)
// if (block0() && threadIdx.x < 64)
// {
// {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment