sglang · Commits · 70f894b8

feat: support flashinfer mla attention for deepseek v3 (#3550)

Unverified commit 70f894b8, authored Feb 14, 2025 by Yineng Zhang, committed by GitHub on Feb 14, 2025.
Parent: 368de366

Showing 12 changed files with 296 additions and 132 deletions (+296, -132).
.github/workflows/pr-test.yml                               +8   -8
python/pyproject.toml                                        +3   -2
python/sglang/global_config.py                               +2   -0
python/sglang/srt/entrypoints/engine.py                      +1   -1
python/sglang/srt/layers/attention/flashinfer_backend.py   +234 -109
python/sglang/srt/managers/schedule_batch.py                 +1   -0
python/sglang/srt/model_executor/model_runner.py            +12   -2
python/sglang/srt/models/deepseek_v2.py                     +13   -7
python/sglang/srt/server_args.py                             +7   -0
python/sglang/srt/utils.py                                   +7   -0
scripts/ci_install_dependency.sh                             +5   -3
test/srt/test_eagle_infer.py                                 +3   -0
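Before the per-file diffs: the new backend is opt-in. Below is a minimal sketch of enabling it through sglang's offline Engine API; the model id, sampling parameters, and the assumption that Engine forwards keyword arguments to the new ServerArgs field are illustrative and not part of this commit.

    # Minimal sketch (assumed model id and sampling params); enable_flashinfer_mla
    # is the new ServerArgs field introduced in this commit.
    import sglang as sgl

    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V3",  # assumed model id
        trust_remote_code=True,
        enable_flashinfer_mla=True,            # new flag from this commit
    )
    out = llm.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 8})
    print(out)
    llm.shutdown()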
.github/workflows/pr-test.yml

The same one-line change is applied in every job that installs dependencies: the FLASHINFER_REPO wheel index gains a "-python" suffix in both branches of the conditional. Representative hunk:

@@ -72,7 +72,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh

Identical hunks appear at @@ -98,7 +98,7 @@, @@ -123,7 +123,7 @@, @@ -163,7 +163,7 @@, @@ -209,7 +209,7 @@, @@ -243,7 +243,7 @@, @@ -283,7 +283,7 @@, and @@ -308,7 +308,7 @@ (the last two additionally show the existing "git clone https://github.com/merrymercy/human-eval.git" line as unchanged context).
python/pyproject.toml

@@ -21,12 +21,13 @@ runtime_common = [
   "hf_transfer", "huggingface_hub", "interegular", "modelscope",
   "orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
   "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
-  "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10"
+  "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10", "ninja"
 ]
 srt = [
   "sglang[runtime_common]", "cuda-python",
   "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
-  "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
+  "flashinfer_python>=0.2.1.post1", "outlines>=0.0.44,<=0.1.11",
 ]
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
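The two dependency changes travel together: flashinfer_python is bumped to 0.2.1.post1, the version that ships the MLA wrapper used below, and ninja is added to runtime_common, presumably for FlashInfer's JIT compilation path (an inference from the cache-clearing line added to scripts/ci_install_dependency.sh further down, not something the commit states). A small environment check, as a sketch:

    # Sketch: verify the bumped FlashInfer requirement and the new ninja
    # dependency are available; the thresholds mirror this diff.
    import shutil
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import Version

    try:
        installed = Version(version("flashinfer_python"))
    except PackageNotFoundError:
        raise SystemExit("flashinfer_python is not installed")

    if installed < Version("0.2.1.post1"):
        raise SystemExit(f"flashinfer_python {installed} < 0.2.1.post1")
    if shutil.which("ninja") is None:
        raise SystemExit("ninja not found on PATH")
    print("ok:", installed)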
python/sglang/global_config.py

@@ -38,5 +38,7 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
 
+        self.enable_flashinfer_mla = False
+
 
 global_config = GlobalConfig()
python/sglang/srt/entrypoints/engine.py

@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.0.post2",
+            "0.2.1.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto

@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper
 
 
 class WrapperDispatch(Enum):

@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
 
 
 @dataclass

@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:

@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]

@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
-                )
-            )
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
+                )
 
         # Create indices updater
         if not skip_prefill:

@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
-                    )
-                )
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
+                    )
 
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,

@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
             [existing non-MLA extend path, unchanged apart from re-indentation
              under the new else branch: paged prefill_wrapper_paged.forward when
             use_ragged is off, otherwise ragged prefill combined with the paged
             output via merge_state; still returns
             o.view(-1, layer.tp_q_head_num * layer.head_dim)]
 
     def forward_decode(
         self,

@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:

@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):

The same widening of the decode_wrappers annotation is applied to the second update method at @@ -528,7 +617,9 @@.

@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,

@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )
 
 
 class FlashInferIndicesUpdaterPrefill:

@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
         # extend part
         if use_ragged:
-            wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                q_data_type=self.q_data_type,
-            )
-
-        # cached part
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-            non_blocking=True,
-        )
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
+            )
 
 
 class FlashInferMultiStepDraftBackend:
     """

@@ -1163,6 +1287,7 @@ def fast_decode_plan(
             window_left,
             logits_soft_cap,
             head_dim,
+            head_dim,
             empty_q_data,
             empty_kv_cache,
             stream.cuda_stream,
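To make the decode-side change easier to read: the MLA wrapper consumes the query and the cached compressed KV split into a "nope" part and a rotary part. The shape sketch below is illustrative only; the 512/64 split and the 1/sqrt(192) scale are taken from the plan call in this diff, while the concrete head_dim/v_head_dim values and the head count are assumptions based on DeepSeek V3's published MLA shapes, not on this commit.

    # Sketch of the tensor split used by decode_wrapper.run() in the new
    # forward_decode branch (assumed dims: head_dim = 576 = 512 latent + 64
    # rotary, v_head_dim = 512; 16 heads is arbitrary).
    import torch

    num_heads, head_dim, v_head_dim = 16, 576, 512
    q = torch.randn(4, num_heads, head_dim)   # decode queries, one row per token
    k = torch.randn(1024, 1, head_dim)        # compressed KV-cache entries

    q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]
    k_nope, k_rope = k[:, :, :v_head_dim], k[:, :, v_head_dim:]
    # These four tensors mirror the arguments passed to decode_wrapper.run(...);
    # the wrapper's output is then reshaped to (-1, num_heads * v_head_dim).
    print(q_nope.shape, q_rope.shape, k_nope.shape, k_rope.shape)

The sm_scale = 1.0 / math.sqrt(192) in the plan call corresponds to the pre-absorption query/key head dimension (128 nope + 64 rope), which is also the head_dim_qk used by the ragged prefill path above.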
python/sglang/srt/managers/schedule_batch.py

@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
 }
 
 logger = logging.getLogger(__name__)
python/sglang/srt/model_executor/model_runner.py

@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )
 
 logger = logging.getLogger(__name__)

@@ -110,8 +111,14 @@ class ModelRunner:
         ):
             # TODO: add MLA optimization on CPU
             if self.server_args.device != "cpu":
-                logger.info("MLA optimization is turned on. Use triton backend.")
-                self.server_args.attention_backend = "triton"
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    )
+                    self.server_args.attention_backend = "flashinfer"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(

@@ -169,6 +176,7 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
             }
         )

@@ -292,6 +300,8 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
+        set_cuda_arch()
+
         # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
python/sglang/srt/models/deepseek_v2.py

@@ -510,14 +510,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        # Use normal computation for prefill and use weight absorption for extend/decode
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.extend_prefix_lens.sum() == 0
-        ):
-            return self.forward_normal(positions, hidden_states, forward_batch)
-        else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+        if global_server_args_dict["enable_flashinfer_mla"]:
+            if forward_batch.forward_mode.is_extend():
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
+        else:
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            if (
+                forward_batch.forward_mode.is_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            ):
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
 
     def forward_normal(
         self,
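The dispatch rule above can be restated compactly; the function below is a hypothetical stand-alone summary for readability, not code from the repository.

    # Sketch of when each attention path is taken (booleans are illustrative).
    def pick_forward(enable_flashinfer_mla: bool, is_extend: bool, no_cached_prefix: bool) -> str:
        if enable_flashinfer_mla:
            # FlashInfer MLA: ragged prefill for any extend batch,
            # weight absorption (MLA decode wrapper) otherwise.
            return "forward_normal" if is_extend else "forward_absorb"
        # Triton path (unchanged): normal computation only when nothing is cached.
        return "forward_normal" if (is_extend and no_cached_prefix) else "forward_absorb"

    print(pick_forward(True, True, False))   # forward_normal
    print(pick_forward(False, True, False))  # forward_absorb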
python/sglang/srt/server_args.py

@@ -168,6 +168,8 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
 
+    enable_flashinfer_mla: bool = False
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:

@@ -693,6 +695,11 @@ class ServerArgs:
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mla",
+            action="store_true",
+            help="Enable FlashInfer MLA optimization",
+        )
 
         # Speculative decoding
         parser.add_argument(
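Equivalently, the new flag can be passed on the command line to the existing server entrypoint. A sketch follows (wrapped in subprocess only to keep the example in Python; the model id is an assumption, and this is the CLI counterpart of the Engine kwarg shown near the top of this page).

    # Sketch: launch the HTTP server with the new --enable-flashinfer-mla flag.
    import subprocess

    subprocess.run(
        [
            "python", "-m", "sglang.launch_server",
            "--model-path", "deepseek-ai/DeepSeek-V3",  # assumed model id
            "--trust-remote-code",
            "--enable-flashinfer-mla",
        ],
        check=True,
    )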
python/sglang/srt/utils.py

@@ -1444,3 +1444,10 @@ def launch_dummy_health_check_server(host, port):
         timeout_keep_alive=5,
         loop="uvloop",
     )
+
+
+def set_cuda_arch():
+    if is_flashinfer_available():
+        capability = torch.cuda.get_device_capability()
+        arch = f"{capability[0]}.{capability[1]}"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
scripts/ci_install_dependency.sh

@@ -4,17 +4,19 @@ set -euxo pipefail
 # Install the dependency in CI.
 
 # Use repo from environment variable, passed from GitHub Actions
-FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}"
+FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
 
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 bash "${SCRIPT_DIR}/killall_sglang.sh"
 
 pip install --upgrade pip
 pip uninstall flashinfer -y
-pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
+pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+rm -rf /root/.cache/flashinfer
 
 # Force reinstall flashinfer and torch_memory_saver
-pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+pip install flashinfer_python==0.2.1.post1 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
 pip install torch_memory_saver --force-reinstall
 
 pip install transformers==4.45.2 sentence_transformers accelerate peft
test/srt/test_eagle_infer.py

@@ -28,6 +28,7 @@ class TestEAGLEEngine(unittest.TestCase):
         "speculative_eagle_topk": 8,
         "speculative_num_draft_tokens": 64,
         "mem_fraction_static": 0.7,
+        "cuda_graph_max_bs": 32,
     }
 
     def setUp(self):

@@ -124,6 +125,8 @@ class TestEAGLEServer(unittest.TestCase):
                 "64",
                 "--mem-fraction-static",
                 "0.7",
+                "--cuda-graph-max-bs",
+                "32",
             ],
         )