Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
68971b5c
Commit
68971b5c
authored
Feb 11, 2026
by
zhanghj2
Browse files
对接口进行架构检查
parent
68055db7
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
32 additions
and
13 deletions
+32
-13
csrc/api/common.h
csrc/api/common.h
+19
-1
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+2
-2
csrc/api/dense_decode_kvfp8.h
csrc/api/dense_decode_kvfp8.h
+2
-2
csrc/api/dense_decode_qkvfp8.h
csrc/api/dense_decode_qkvfp8.h
+2
-2
csrc/api/sparse_decode.h
csrc/api/sparse_decode.h
+3
-1
csrc/api/sparse_fwd.h
csrc/api/sparse_fwd.h
+4
-5
No files found.
csrc/api/common.h
View file @
68971b5c
...
...
@@ -6,7 +6,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <kerutils/supplemental/torch_tensors.h>
#include <string>
#include <cutlass/bfloat16.h>
static
constexpr
float
LOG_2_E
=
1.44269504
f
;
...
...
@@ -22,6 +22,7 @@ struct Arch {
int
major
;
int
minor
;
int
num_sms
;
std
::
string
archName
;
cudaDeviceProp
*
device_prop
;
Arch
()
{
...
...
@@ -29,6 +30,7 @@ struct Arch {
major
=
device_prop
->
major
;
minor
=
device_prop
->
minor
;
num_sms
=
device_prop
->
multiProcessorCount
;
archName
=
device_prop
->
gcnArchName
;
}
bool
is_sm90a
()
const
{
...
...
@@ -38,6 +40,22 @@ struct Arch {
/// Reports whether the cached device `major` version equals 10.
/// `major` is copied from `device_prop->major` in the constructor.
bool is_sm100f() const {
    const bool major_is_ten = (major == 10);
    return major_is_ten;
}
/// Reports whether the device architecture name is exactly "gfx938".
/// Only the part of `archName` before the first ':' is compared; when no
/// ':' is present the whole string is compared.
bool is_gfx938() const {
    // find() yields npos when there is no ':'; compare() then clamps the
    // length to the full string, matching substr(0, npos).
    const auto colon_pos = archName.find(':');
    return archName.compare(0, colon_pos, "gfx938") == 0;
}
/// Reports whether the device architecture name is exactly "gfx936".
/// Only the part of `archName` before the first ':' is compared; when no
/// ':' is present the whole string is compared.
bool is_gfx936() const {
    // find() yields npos when there is no ':'; compare() then clamps the
    // length to the full string, matching substr(0, npos).
    const auto colon_pos = archName.find(':');
    return archName.compare(0, colon_pos, "gfx936") == 0;
}
/// Reports whether the device architecture name is exactly "gfx928".
/// Only the part of `archName` before the first ':' is compared; when no
/// ':' is present the whole string is compared.
bool is_gfx928() const {
    // find() yields npos when there is no ':'; compare() then clamps the
    // length to the full string, matching substr(0, npos).
    const auto colon_pos = archName.find(':');
    return archName.compare(0, colon_pos, "gfx928") == 0;
}
/// Reports whether the device is either of the two gfx93x variants this
/// struct recognises: gfx936 or gfx938.
bool is_gfx93x() const {
    // gfx936 is tested first, so gfx938 is only queried when that fails.
    if (is_gfx936()) {
        return true;
    }
    return is_gfx938();
}
};
...
...
csrc/api/dense_decode.h
View file @
68971b5c
...
...
@@ -24,8 +24,8 @@ dense_attn_decode_interface(
)
{
// Check arch
Arch
arch
=
Arch
();
if
(
!
arch
.
is_
sm90a
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on
SM90a
architecture"
);
if
(
!
arch
.
is_
gfx93x
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on
gfx936 or gfx938
architecture"
);
}
// Check data types
...
...
csrc/api/dense_decode_kvfp8.h
View file @
68971b5c
...
...
@@ -26,8 +26,8 @@ dense_attn_decode_kvfp8_interface(
)
{
// Check arch
Arch
arch
=
Arch
();
if
(
!
arch
.
is_
sm90a
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on
SM90a
architecture"
);
if
(
!
arch
.
is_
gfx93x
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on
gfx936 or gfx938
architecture"
);
}
// Check data types
...
...
csrc/api/dense_decode_qkvfp8.h
View file @
68971b5c
...
...
@@ -26,8 +26,8 @@ dense_attn_decode_qkvfp8_interface(
)
{
// Check arch
Arch
arch
=
Arch
();
if
(
!
arch
.
is_
sm90a
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on
SM90a
architecture"
);
if
(
!
arch
.
is_
gfx938
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on
gfx938
architecture"
);
}
// Check data types
...
...
csrc/api/sparse_decode.h
View file @
68971b5c
...
...
@@ -101,7 +101,9 @@ sparse_attn_decode_interface(
// Check the architecture
Arch
arch
=
Arch
();
if
(
!
arch
.
is_gfx93x
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx936 or gfx938 architecture"
);
}
KU_CHECK_NDIM
(
q
,
4
);
KU_CHECK_NDIM
(
kv
,
4
);
KU_CHECK_NDIM
(
indices
,
3
);
...
...
csrc/api/sparse_fwd.h
View file @
68971b5c
...
...
@@ -59,10 +59,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
using
bf16
=
cutlass
::
bfloat16_t
;
Arch
arch
=
Arch
();
bool
is_sm90a
=
arch
.
is_sm90a
();
bool
is_sm100f
=
arch
.
is_sm100f
();
TORCH_CHECK
(
is_sm90a
||
is_sm100f
,
"Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures."
);
if
(
!
arch
.
is_gfx93x
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx936 or gfx938 architecture"
);
}
KU_CHECK_NDIM
(
q
,
3
);
KU_CHECK_NDIM
(
kv
,
3
);
KU_CHECK_NDIM
(
indices
,
3
);
...
...
@@ -161,7 +160,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features
.
push_back
(
FwdFeatures
::
TOPK_LENGTH
);
}
if
(
is_sm90a
)
{
if
(
arch
.
is_gfx93x
()
)
{
Fwd_Sm90_Impl
fwd_impl
;
fwd_impl
.
run
(
params
,
required_features
);
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment