sglang · Commit b1100846 (unverified)

Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)

Authored Feb 24, 2025 by Baizhou Zhang; committed via GitHub on Feb 24, 2025.
Parent: 27a46317

Showing 4 changed files with 565 additions and 19 deletions (+565, -19).
Changed files:

- python/sglang/srt/configs/model_config.py (+20, -0)
- python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+521, -0)
- python/sglang/srt/model_executor/model_runner.py (+5, -2)
- python/sglang/srt/models/deepseek_v2.py (+19, -17)
python/sglang/srt/configs/model_config.py (view file @ b1100846)

@@ -14,6 +14,7 @@

```python
import json
import logging
import math
from enum import IntEnum, auto
from typing import List, Optional, Set, Union
```

@@ -103,7 +104,20 @@ class ModelConfig:

```python
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get("mscale_all_dim", False)
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
```

@@ -414,3 +428,9 @@ def is_multimodal_model(model_architectures: List[str]):

```python
def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
```
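As a sanity check on the new scaling path, here is a small worked example. It is a sketch assuming DeepSeek-V3-style values (qk_nope_head_dim=128, qk_rope_head_dim=64, YaRN factor 40, mscale_all_dim=1.0); these numbers are illustrative assumptions, not taken from this diff:

```python
import math

# Assumed DeepSeek-V3-style values (illustration only, not read from this commit).
qk_nope_head_dim = 128
qk_rope_head_dim = 64
rope_scaling = {"factor": 40, "mscale_all_dim": 1.0}


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Same helper the commit adds to model_config.py.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Base softmax scale uses the full q/k head dim (nope + rope parts).
scaling = 1 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)  # 1/sqrt(192) ~ 0.0722

# With YaRN rope scaling, the scale is multiplied by mscale twice.
mscale = yarn_get_mscale(rope_scaling["factor"], float(rope_scaling["mscale_all_dim"]))
scaling = scaling * mscale * mscale
print(f"mscale={mscale:.4f}, scaling={scaling:.5f}")  # mscale ~ 1.3689, scaling ~ 0.13524
```

The precomputed `model_config.scaling` is what the new flashinfer MLA backend hands to the kernels as `sm_scale`, which appears to be where the accuracy fix matters once YaRN rope scaling is active.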
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (new file, mode 100644, view file @ b1100846)

```python
from __future__ import annotations

"""
Support attention backend for flashinfer MLA.
When radix cache is enabled, the backend only uses BatchMLAPaged wrapper when forwarding.
When radix cache is disabled, the backend uses BatchPrefill wrappers for prefilling (with or without prefix cache),
and uses BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
    create_flashinfer_kv_indices_triton,
    should_use_tensor_core,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
    from flashinfer import (
        BatchPrefillWithPagedKVCacheWrapper,
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from flashinfer.cascade import merge_state
    from flashinfer.mla import BatchMLAPagedAttentionWrapper


@dataclass
class DecodeMetadata:
    decode_wrapper: BatchMLAPagedAttentionWrapper


@dataclass
class PrefillMetadata:
    prefill_wrapper: Union[
        BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
    ]
    use_ragged: bool


# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None


class FlashInferMLAAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len

        global_config.enable_flashinfer_mla = True

        # Allocate buffers
        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )

        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        if not global_server_args_dict["disable_radix_cache"]:
            # use mla paged prefill
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )
        else:
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                backend="auto",
            )

        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updater
        self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
            model_runner, self
        )
        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                decode_wrapper=self.decode_wrapper,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        else:
            prefix_lens = forward_batch.extend_prefix_lens

            use_ragged = global_server_args_dict["disable_radix_cache"]

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=use_ragged,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, use_ragged
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len,),
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = kv_indices_buf

        self.cuda_graph_kv_indices = cuda_graph_kv_indices
        self.cuda_graph_custom_mask = torch.zeros(
            (max_bs * self.max_context_len),
            dtype=torch.uint8,
            device="cuda",
        )
        self.cuda_graph_qk_indptr = self.kv_indptr.clone()
        self.cuda_graph_qo_indptr = self.kv_indptr.clone()

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            decode_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.qo_indptr[: num_tokens + 1],
                kv_indptr=self.kv_indptr[: num_tokens + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.kv_last_page_len[:num_tokens],
                backend="auto",
            )

            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_decode.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        cache_loc = forward_batch.out_cache_loc
        logits_soft_cap = layer.logit_cap

        if not global_server_args_dict["disable_radix_cache"]:
            # use mla paged prefill
            prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
            if k is not None:
                assert v is not None
                if save_kv_cache:
                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

            o = prefill_wrapper_paged.run(
                qall[:, :, : layer.v_head_dim],
                qall[:, :, layer.v_head_dim :],
                k_buf[:, :, : layer.v_head_dim],
                k_buf[:, :, layer.v_head_dim :],
            )
        else:
            # use mla ragged prefill
            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
            # FIXME: Here should be another prefill_paged to call

            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
        o = decode_wrapper.run(
            reshaped_q[:, :, : layer.v_head_dim],
            reshaped_q[:, :, layer.v_head_dim :],
            reshaped_k[:, :, : layer.v_head_dim],
            reshaped_k[:, :, layer.v_head_dim :],
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)


class FlashInferMLAIndicesUpdaterDecode:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
    ):
        decode_wrappers = decode_wrapper or self.decode_wrapper
        self.call_begin_forward(
            decode_wrapper,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            self.kv_indptr,
        )

    def call_begin_forward(
        self,
        wrapper: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
    ):
        bs = len(req_pool_indices)
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.shape[1],
        )
        sm_scale = self.scaling
        q_indptr = torch.arange(0, bs + 1).to(0).int()
        kv_lens = paged_kernel_lens.to(torch.int32)
        wrapper.plan(
            q_indptr,
            kv_indptr,
            kv_indices,
            kv_lens,
            self.num_local_heads,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
            1,
            False,
            sm_scale,
            self.data_type,
            self.data_type,
        )


class FlashInferMLAIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper_paged: Union[
            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
        ],
        use_ragged: bool,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            self.prefill_wrapper_ragged,
            prefill_wrapper_paged,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
        )

    def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
        wrapper_paged: Union[
            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
        ],
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
    ):
        bs = len(req_pool_indices)

        # Normal extend
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = torch.empty(
            paged_kernel_lens_sum,
            dtype=torch.int32,
            device=req_pool_indices.device,
        )
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.shape[1],
        )

        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]
        sm_scale = self.scaling

        # extend part
        if use_ragged:
            wrapper_ragged.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=qo_indptr,
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
            )

        if not global_server_args_dict["disable_radix_cache"]:
            # mla paged prefill
            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
            wrapper_paged.plan(
                qo_indptr,
                kv_indptr,
                kv_indices,
                kv_len_arr,
                self.num_qo_heads,
                self.kv_lora_rank,
                self.qk_rope_head_dim,
                1,
                True,
                sm_scale,
                self.q_data_type,
                self.data_type,
            )

        # FIXME: Here should be some logic for prefill paged when not using radix cache?
```
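The wrappers above treat the query and the cached key as a compressed ("nope") part plus a rope part, split at `layer.v_head_dim`, and describe the paged KV cache with a CSR-style `kv_indptr`/`kv_indices` pair. Below is a minimal sketch of that bookkeeping in plain torch, assuming DeepSeek-V3-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64); the variable names and sizes are illustrative assumptions, not sglang APIs:

```python
import torch

# Assumed MLA dimensions for illustration (DeepSeek-V3-style); not read from this commit.
kv_lora_rank = 512        # compressed kv / "nope" part, also the value head dim in absorbed mode
qk_rope_head_dim = 64     # rope part appended per head
num_heads = 16            # tp-local query heads (hypothetical)
head_dim = kv_lora_rank + qk_rope_head_dim   # 576, plays the role of layer.head_dim in decode
v_head_dim = kv_lora_rank                    # 512, plays the role of layer.v_head_dim

q = torch.randn(4, num_heads, head_dim)      # 4 decode tokens
q_nope = q[:, :, :v_head_dim]                # (4, 16, 512), compressed query
q_rope = q[:, :, v_head_dim:]                # (4, 16, 64), rotary query

# kv_indptr/kv_indices form a CSR-style layout over the paged KV cache:
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)          # per-request kv lengths
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)                  # [0, 3, 8, 10]
# kv_indices would then hold the 10 physical cache slots, filled per request,
# which is what create_flashinfer_kv_indices_triton does on the GPU.
print(q_nope.shape, q_rope.shape, kv_indptr.tolist())
```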
python/sglang/srt/model_executor/model_runner.py (view file @ b1100846)

@@ -34,6 +34,7 @@ from sglang.srt.distributed import (

```python
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
```

@@ -113,9 +114,9 @@ class ModelRunner:

Before:

```python
        if self.server_args.device != "cpu":
            if server_args.enable_flashinfer_mla:
                logger.info(
                    "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
                )
                self.server_args.attention_backend = "flashinfer"
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                self.server_args.attention_backend = "triton"
```

After:

```python
        if self.server_args.device != "cpu":
            if server_args.enable_flashinfer_mla:
                logger.info(
                    "MLA optimization is turned on. Use flashinfer mla backend."
                )
                self.server_args.attention_backend = "flashinfer_mla"
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                self.server_args.attention_backend = "triton"
```

@@ -703,6 +704,8 @@ class ModelRunner:

```python
            self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashinfer_mla":
            self.attn_backend = FlashInferMLAAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
```
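With this change, MLA models are routed to the dedicated `flashinfer_mla` backend whenever the `enable_flashinfer_mla` server arg is set (on the CLI this would normally surface as `--enable-flashinfer-mla`, assuming sglang's usual underscore-to-dash mapping). A minimal sketch of the routing rule, using plain booleans as stand-ins for the server args rather than sglang objects:

```python
# Sketch of the backend routing introduced in this hunk; the parameters are
# hypothetical stand-ins for ServerArgs/ModelConfig fields, not sglang code.
def pick_attention_backend(is_mla_model: bool, enable_flashinfer_mla: bool, device: str) -> str:
    if device != "cpu" and is_mla_model:
        # New in this commit: MLA + flashinfer now selects "flashinfer_mla"
        # instead of the generic "flashinfer" backend.
        return "flashinfer_mla" if enable_flashinfer_mla else "triton"
    return "flashinfer"  # assumed non-MLA default, for illustration only


assert pick_attention_backend(True, True, "cuda") == "flashinfer_mla"
assert pick_attention_backend(True, False, "cuda") == "triton"
```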
python/sglang/srt/models/deepseek_v2.py (view file @ b1100846)

@@ -510,25 +510,27 @@ class DeepseekV2AttentionMLA(nn.Module):

Before:

```python
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if global_server_args_dict["enable_flashinfer_mla"]:
            if global_server_args_dict["disable_radix_cache"]:
                if forward_batch.forward_mode.is_extend():
                    return self.forward_normal(positions, hidden_states, forward_batch)
                else:
                    return self.forward_absorb(positions, hidden_states, forward_batch)
            else:
                return self.forward_absorb(positions, hidden_states, forward_batch)
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and forward_batch.extend_prefix_lens.sum() == 0
            ):
                return self.forward_normal(positions, hidden_states, forward_batch)
            else:
                return self.forward_absorb(positions, hidden_states, forward_batch)
```

After:

```python
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        def no_absorb() -> bool:
            if global_server_args_dict["enable_flashinfer_mla"]:
                # Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
                return (
                    global_server_args_dict["disable_radix_cache"]
                    and forward_batch.forward_mode.is_extend()
                )
            else:
                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
                return (
                    forward_batch.forward_mode.is_extend()
                    and not forward_batch.forward_mode.is_target_verify()
                    and not forward_batch.forward_mode.is_draft_extend()
                    and forward_batch.extend_prefix_lens.sum() == 0
                )

        if no_absorb():
            return self.forward_normal(positions, hidden_states, forward_batch)
        else:
            return self.forward_absorb(positions, hidden_states, forward_batch)

    def forward_normal(
        self,
```
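The refactor collapses the old nested dispatch into a single `no_absorb()` predicate. Below is a compact summary of the resulting behavior as a standalone sketch; the boolean parameters stand in for the server args and `forward_mode` queries and are not sglang objects:

```python
def no_absorb(enable_flashinfer_mla: bool, disable_radix_cache: bool,
              is_extend: bool, is_target_verify: bool = False,
              is_draft_extend: bool = False, has_prefix: bool = False) -> bool:
    """Return True when the normal (non weight-absorbed) MLA path should run."""
    if enable_flashinfer_mla:
        # Flashinfer MLA: skip absorption only for ragged prefill/extend,
        # which this commit ties to the radix cache being disabled.
        return disable_radix_cache and is_extend
    # Triton: normal computation only for a pure prefill with no cached prefix.
    return is_extend and not is_target_verify and not is_draft_extend and not has_prefix


# With flashinfer MLA and the radix cache enabled, even prefill takes the absorbed (paged) path:
assert no_absorb(True, False, is_extend=True) is False
# With the radix cache disabled, prefill falls back to the ragged, non-absorbed path:
assert no_absorb(True, True, is_extend=True) is True
```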