sglang · commit 9971dc22 (unverified)

Revert "feat: Add FlashMLA submodule (#4449)" (#4470)

Authored Mar 16, 2025 by Yineng Zhang; committed by GitHub on Mar 16, 2025
Parent: 3db35c1a
Showing 3 changed files, with 0 additions and 205 deletions (+0 / -205):

  .gitmodules                          +0  -3
  sgl-kernel/setup.py                  +0  -49
  sgl-kernel/tests/test_flash_mla.py   +0  -153
.gitmodules  (view file @ 9971dc22)

@@ -10,6 +10,3 @@
 [submodule "sgl-kernel/3rdparty/deepgemm"]
     path = sgl-kernel/3rdparty/deepgemm
     url = https://github.com/deepseek-ai/DeepGEMM
-[submodule "sgl-kernel/3rdparty/flashmla"]
-    path = sgl-kernel/3rdparty/flashmla
-    url = https://github.com/deepseek-ai/FlashMLA
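
The setup.py logic removed below only did anything useful when the flashmla checkout under sgl-kernel/3rdparty/ was actually populated; its warning text ("possibly the submodule is not properly initialized") names the failure mode. As a minimal sketch of how a build script could detect that situation up front (a hypothetical helper, not code from this repository; an uninitialized submodule is just an empty directory in the worktree):

# Hypothetical helper, not part of sglang: check whether a vendored
# submodule checkout (e.g. sgl-kernel/3rdparty/flashmla) is populated.
from pathlib import Path


def submodule_present(repo_root: Path, rel_path: str) -> bool:
    # An uninitialized submodule exists as an empty directory, so an empty
    # listing means `git submodule update --init --recursive` was skipped.
    path = repo_root / rel_path
    return path.is_dir() and any(path.iterdir())


if __name__ == "__main__":
    print(submodule_present(Path("."), "sgl-kernel/3rdparty/flashmla"))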
sgl-kernel/setup.py  (view file @ 9971dc22)

@@ -18,21 +18,6 @@ import shutil
 import sys
 from pathlib import Path
 
-# Setup flash_mla at the top level for tests to find
-# This makes the module importable without installation
-root_dir = Path(__file__).parent.resolve()
-module_src = root_dir / "3rdparty" / "flashmla" / "flash_mla"
-module_dest = root_dir / "flash_mla"
-
-if module_src.exists() and not module_dest.exists():
-    try:
-        os.symlink(module_src, module_dest, target_is_directory=True)
-        print(f"Created symbolic link from {module_src} to {module_dest}")
-    except (OSError, NotImplementedError):
-        if module_src.exists():
-            shutil.copytree(module_src, module_dest)
-            print(f"Copied directory from {module_src} to {module_dest}")
-
 import torch
 from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py

@@ -70,7 +55,6 @@ cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
 deepgemm = root / "3rdparty" / "deepgemm"
-flashmla = root / "3rdparty" / "flashmla"
 include_dirs = [
     root / "include",
     root / "csrc",

@@ -79,7 +63,6 @@ include_dirs = [
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
-    flashmla.resolve() / "csrc",
     "cublas",
 ]

@@ -87,7 +70,6 @@ include_dirs = [
 class CustomBuildPy(build_py):
     def run(self):
         self.copy_deepgemm_to_build_lib()
-        self.copy_flashmla_to_build_lib()
         self.make_jit_include_symlinks()
         build_py.run(self)

@@ -111,17 +93,6 @@ class CustomBuildPy(build_py):
                 os.unlink(dst_dir)
             os.symlink(src_dir, dst_dir, target_is_directory=True)
 
-        # Create symbolic links for FlashMLA
-        flash_mla_include_dir = os.path.join(self.build_lib, "flash_mla/include")
-        os.makedirs(flash_mla_include_dir, exist_ok=True)
-
-        # Create empty directories for FlashMLA's include paths
-        # This is safer than creating symlinks as the targets might not exist in CI
-        for dirname in ["cute", "cutlass"]:
-            dst_dir = f"{flash_mla_include_dir}/{dirname}"
-            if not os.path.exists(dst_dir):
-                os.makedirs(dst_dir, exist_ok=True)
-
     def copy_deepgemm_to_build_lib(self):
         """
         This function copies DeepGemm to python's site-packages

@@ -139,26 +110,6 @@ class CustomBuildPy(build_py):
         # Copy the directory
         shutil.copytree(src_dir, dst_dir)
 
-    def copy_flashmla_to_build_lib(self):
-        """
-        This function copies FlashMLA to python's site-packages
-        """
-        dst_dir = os.path.join(self.build_lib, "flash_mla")
-        os.makedirs(dst_dir, exist_ok=True)
-
-        src_dir = os.path.join(str(flashmla.resolve()), "flash_mla")
-        if not os.path.exists(src_dir):
-            print(
-                f"Warning: Source directory {src_dir} does not exist, possibly the submodule is not properly initialized"
-            )
-            return
-
-        if os.path.exists(dst_dir):
-            shutil.rmtree(dst_dir)
-        shutil.copytree(src_dir, dst_dir)
-
 nvcc_flags = [
     "-DNDEBUG",
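
Both the deleted copy_flashmla_to_build_lib step and the surviving copy_deepgemm_to_build_lib step hang off the same setuptools hook: subclass build_py, copy vendored sources into self.build_lib, then defer to the stock run(). A minimal, self-contained sketch of that pattern, with placeholder names rather than sgl-kernel's real layout:

# Minimal sketch of the build_py-override pattern used in sgl-kernel/setup.py.
# Directory and package names below are illustrative placeholders.
import os
import shutil
from pathlib import Path

from setuptools import setup
from setuptools.command.build_py import build_py


class VendorCopyBuildPy(build_py):
    def run(self):
        # Copy a vendored source tree next to the built package so it is
        # importable from site-packages after installation.
        src_dir = Path(__file__).parent.resolve() / "3rdparty" / "vendored_pkg"
        dst_dir = os.path.join(self.build_lib, "vendored_pkg")
        if src_dir.exists():
            if os.path.exists(dst_dir):
                shutil.rmtree(dst_dir)
            shutil.copytree(src_dir, dst_dir)
        # Let setuptools do the normal build_py work afterwards.
        build_py.run(self)


setup(
    name="example-kernel",
    version="0.0.0",
    cmdclass={"build_py": VendorCopyBuildPy},
)

With this revert applied, CustomBuildPy keeps that structure but only handles DeepGEMM and the JIT include symlinks.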
sgl-kernel/tests/test_flash_mla.py  deleted  100644 → 0  (view file @ 3db35c1a)

The entire file (153 lines) is removed by this commit; the deleted contents were:

import argparse
import math
import random

import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

"""
fork FlashMLA/tests/test_flash_mla.py
"""


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    x, y = x.double(), y.double()
    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
    amax_diff = (x - y).abs().max().item()
    # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
    assert cos_diff < 1e-5


@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
    print(
        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
    )

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32
    ).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
            float("nan")
        )
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )

    def flash_mla():
        return flash_mla_with_kvcache(
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            dv,
            tile_scheduler_metadata,
            num_splits,
            causal=causal,
        )

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q=h_q,
                h_kv=h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_flash, lse_flash = flash_mla()
    out_torch, lse_torch = ref_mla()
    cal_diff(out_flash, out_torch, "out")
    cal_diff(lse_flash, lse_torch, "lse")

    t = triton.testing.do_bench(flash_mla)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
        torch.finfo(q.dtype).bits // 8
    )
    print(
        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
    )


def main(torch_dtype):
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)

    h_kv = 1
    d, dv = 576, 512
    causal = True

    for b in [128]:
        for s in [4096, 8192]:
            for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
                for s_q in [1, 2]:  # MTP = 1, 2
                    for varlen in [False, True]:
                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Data type to use for testing (bf16 or fp16)",
    )
    args = parser.parse_args()
    torch_dtype = torch.bfloat16
    if args.dtype == "fp16":
        torch_dtype = torch.float16

    main(torch_dtype)
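
For reference, the deleted benchmark derives its throughput numbers from two closed-form counts, FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 and bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * bytes_per_element, each divided by the measured kernel time in milliseconds. A standalone check of that arithmetic, using one configuration from main() and an assumed (not measured) kernel time:

# Worked example of the throughput formulas from the deleted test.
# The kernel time below is an assumed placeholder, not a measured result.
b, s_q, mean_sk, h_q, h_kv, d, dv = 128, 1, 4096, 16, 1, 576, 512
total_seqlens = b * mean_sk      # varlen=False: every sequence has length mean_sk
bytes_per_elem = 2               # bf16 / fp16

flops = s_q * total_seqlens * h_q * (d + dv) * 2
num_bytes = (
    total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv
) * bytes_per_elem

t_ms = 0.2                       # assumed kernel time in milliseconds
print(f"{flops / 10**9 / t_ms:.0f} TFLOPS, {num_bytes / 10**6 / t_ms:.0f} GB/s")

Given a CUDA device and an installed flash_mla package, the deleted script itself was runnable directly with: python sgl-kernel/tests/test_flash_mla.py --dtype bf16 (or --dtype fp16).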