Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
59487e20
Commit
59487e20
authored
Feb 24, 2026
by
zhanghj2
Browse files
smxx修改为gfx9
parent
f298a271
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
23 additions
and
23 deletions
+23
-23
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+5
-5
csrc/api/dense_decode_kvfp8.h
csrc/api/dense_decode_kvfp8.h
+4
-4
csrc/api/dense_decode_qkvfp8.h
csrc/api/dense_decode_qkvfp8.h
+4
-4
csrc/api/sparse_decode.h
csrc/api/sparse_decode.h
+4
-4
csrc/gfx9/decode/combine/combine.cu
csrc/gfx9/decode/combine/combine.cu
+1
-1
csrc/gfx9/decode/combine/combine.h
csrc/gfx9/decode/combine/combine.h
+1
-1
csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
...decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
+1
-1
csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h
.../decode/get_decoding_sched_meta/get_decoding_sched_meta.h
+1
-1
setup.py
setup.py
+2
-2
No files found.
csrc/api/dense_decode.h
View file @
59487e20
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
#include "params.h"
#include "params.h"
#include "gfx93/decode/dense/splitkv_mla.h"
#include "gfx93/decode/dense/splitkv_mla.h"
#include "
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
smxx
/decode/combine/combine.h"
#include "
gfx9
/decode/combine/combine.h"
static
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
static
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
dense_attn_decode_interface
(
dense_attn_decode_interface
(
...
@@ -110,7 +110,7 @@ dense_attn_decode_interface(
...
@@ -110,7 +110,7 @@ dense_attn_decode_interface(
num_sm_parts
,
num_sm_parts
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
gfx9
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
}
else
{
}
else
{
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
...
@@ -207,10 +207,10 @@ dense_attn_decode_interface(
...
@@ -207,10 +207,10 @@ dense_attn_decode_interface(
};
};
if
(
q_dtype
==
torch
::
kBFloat16
)
{
if
(
q_dtype
==
torch
::
kBFloat16
)
{
smxx
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
#ifndef FLASH_MLA_DISABLE_FP16
#ifndef FLASH_MLA_DISABLE_FP16
smxx
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
half_t
>
(
combine_params
);
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
half_t
>
(
combine_params
);
#endif
#endif
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported tensor dtype for query"
);
TORCH_CHECK
(
false
,
"Unsupported tensor dtype for query"
);
...
...
csrc/api/dense_decode_kvfp8.h
View file @
59487e20
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
#include "params.h"
#include "params.h"
#include "gfx93/decode/dense_kvfp8/splitkv_mla.h"
#include "gfx93/decode/dense_kvfp8/splitkv_mla.h"
#include "
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
smxx
/decode/combine/combine.h"
#include "
gfx9
/decode/combine/combine.h"
static
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
static
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
dense_attn_decode_kvfp8_interface
(
dense_attn_decode_kvfp8_interface
(
...
@@ -123,7 +123,7 @@ dense_attn_decode_kvfp8_interface(
...
@@ -123,7 +123,7 @@ dense_attn_decode_kvfp8_interface(
num_sm_parts
,
num_sm_parts
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
gfx9
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
}
else
{
}
else
{
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
...
@@ -215,7 +215,7 @@ dense_attn_decode_kvfp8_interface(
...
@@ -215,7 +215,7 @@ dense_attn_decode_kvfp8_interface(
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
out
=
out
.
view
({
batch_size
,
num_heads_k
,
seqlen_q_ori
,
num_q_heads_per_hk
,
head_size_v
}).
transpose
(
1
,
2
)
out
=
out
.
view
({
batch_size
,
num_heads_k
,
seqlen_q_ori
,
num_q_heads_per_hk
,
head_size_v
}).
transpose
(
1
,
2
)
...
...
csrc/api/dense_decode_qkvfp8.h
View file @
59487e20
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
#include "params.h"
#include "params.h"
#include "gfx93/decode/dense_qkvfp8/splitkv_mla.h"
#include "gfx93/decode/dense_qkvfp8/splitkv_mla.h"
#include "
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
smxx
/decode/combine/combine.h"
#include "
gfx9
/decode/combine/combine.h"
static
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
static
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
dense_attn_decode_qkvfp8_interface
(
dense_attn_decode_qkvfp8_interface
(
...
@@ -123,7 +123,7 @@ dense_attn_decode_qkvfp8_interface(
...
@@ -123,7 +123,7 @@ dense_attn_decode_qkvfp8_interface(
num_sm_parts
,
num_sm_parts
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
gfx9
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
}
else
{
}
else
{
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
...
@@ -215,7 +215,7 @@ dense_attn_decode_qkvfp8_interface(
...
@@ -215,7 +215,7 @@ dense_attn_decode_qkvfp8_interface(
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
out
=
out
.
view
({
batch_size
,
num_heads_k
,
seqlen_q_ori
,
num_q_heads_per_hk
,
head_size_v
}).
transpose
(
1
,
2
)
out
=
out
.
view
({
batch_size
,
num_heads_k
,
seqlen_q_ori
,
num_q_heads_per_hk
,
head_size_v
}).
transpose
(
1
,
2
)
...
...
csrc/api/sparse_decode.h
View file @
59487e20
...
@@ -5,8 +5,8 @@
...
@@ -5,8 +5,8 @@
#include "params.h"
#include "params.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "
smxx
/decode/combine/combine.h"
#include "
gfx9
/decode/combine/combine.h"
// Feature set of sparse decoding kernels
// Feature set of sparse decoding kernels
enum
class
DecodeFeatures
:
int
{
enum
class
DecodeFeatures
:
int
{
...
@@ -328,7 +328,7 @@ sparse_attn_decode_interface(
...
@@ -328,7 +328,7 @@ sparse_attn_decode_interface(
impl_meta
.
num_sm_parts
,
impl_meta
.
num_sm_parts
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
gfx9
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
}
}
// Stick the metadata pointers to `params`
// Stick the metadata pointers to `params`
KU_CHECK_DEVICE
(
tile_scheduler_metadata
);
KU_CHECK_DEVICE
(
tile_scheduler_metadata
);
...
@@ -379,7 +379,7 @@ sparse_attn_decode_interface(
...
@@ -379,7 +379,7 @@ sparse_attn_decode_interface(
ku
::
get_optional_tensor_ptr
<
float
>
(
attn_sink
),
ku
::
get_optional_tensor_ptr
<
float
>
(
attn_sink
),
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
smxx
::
decode
::
run_flash_mla_combine_kernel
<
bf16
>
(
combine_params
);
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
bf16
>
(
combine_params
);
delete
impl
;
delete
impl
;
...
...
csrc/
smxx
/decode/combine/combine.cu
→
csrc/
gfx9
/decode/combine/combine.cu
View file @
59487e20
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
using
namespace
cute
;
using
namespace
cute
;
namespace
smxx
::
decode
{
namespace
gfx9
::
decode
{
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
>
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
>
__global__
void
__launch_bounds__
(
NUM_THREADS
,
1
)
__global__
void
__launch_bounds__
(
NUM_THREADS
,
1
)
...
...
csrc/
smxx
/decode/combine/combine.h
→
csrc/
gfx9
/decode/combine/combine.h
View file @
59487e20
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "params.h"
#include "params.h"
namespace
smxx
::
decode
{
namespace
gfx9
::
decode
{
template
<
typename
ElementT
>
template
<
typename
ElementT
>
void
run_flash_mla_combine_kernel
(
CombineParams
&
params
);
void
run_flash_mla_combine_kernel
(
CombineParams
&
params
);
...
...
csrc/
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
→
csrc/
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
View file @
59487e20
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "utils.h"
#include "utils.h"
namespace
smxx
::
decode
{
namespace
gfx9
::
decode
{
__global__
void
__launch_bounds__
(
64
,
1
)
__global__
void
__launch_bounds__
(
64
,
1
)
get_mla_metadata_kernel
(
const
GetDecodeSchedMetaParams
params
)
{
get_mla_metadata_kernel
(
const
GetDecodeSchedMetaParams
params
)
{
...
...
csrc/
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h
→
csrc/
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.h
View file @
59487e20
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "params.h"
#include "params.h"
namespace
smxx
::
decode
{
namespace
gfx9
::
decode
{
void
run_get_decoding_sched_meta_kernel
(
GetDecodeSchedMetaParams
&
params
);
void
run_get_decoding_sched_meta_kernel
(
GetDecodeSchedMetaParams
&
params
);
...
...
setup.py
View file @
59487e20
...
@@ -52,8 +52,8 @@ ext_modules.append(
...
@@ -52,8 +52,8 @@ ext_modules.append(
"csrc/api/api.cpp"
,
"csrc/api/api.cpp"
,
# # Misc kernels for decoding
# # Misc kernels for decoding
"csrc/
smxx
/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu"
,
"csrc/
gfx9
/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu"
,
"csrc/
smxx
/decode/combine/combine.cu"
,
"csrc/
gfx9
/decode/combine/combine.cu"
,
# # gfx93 dense decode
# # gfx93 dense decode
"csrc/gfx93/decode/dense/instantiations/fp16.cu"
,
"csrc/gfx93/decode/dense/instantiations/fp16.cu"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment