change / sglang · Commit 2695ab05 (Unverified)

Fix loading KV quantization scale; Enable modelopt kv cache (#4686)

Authored Apr 08, 2025 by Yun Dai; committed by GitHub on Apr 08, 2025
Co-authored-by: qingquansong <ustcsqq@gmail.com>
Parent: 88d6fd9a
Changes: 38 files in total; this page shows 20 changed files with 135 additions and 73 deletions (+135, -73).
Changed files on this page:
  python/sglang/srt/configs/model_config.py (+1, -1)
  python/sglang/srt/layers/attention/flashattention_backend.py (+23, -8)
  python/sglang/srt/layers/attention/flashinfer_backend.py (+13, -7)
  python/sglang/srt/layers/quantization/kv_cache.py (+43, -52)
  python/sglang/srt/layers/quantization/modelopt_quant.py (+25, -4)
  python/sglang/srt/layers/radix_attention.py (+13, -1)
  python/sglang/srt/models/baichuan.py (+2, -0)
  python/sglang/srt/models/chatglm.py (+1, -0)
  python/sglang/srt/models/commandr.py (+1, -0)
  python/sglang/srt/models/dbrx.py (+1, -0)
  python/sglang/srt/models/deepseek.py (+1, -0)
  python/sglang/srt/models/deepseek_v2.py (+3, -0)
  python/sglang/srt/models/exaone.py (+1, -0)
  python/sglang/srt/models/gemma.py (+1, -0)
  python/sglang/srt/models/gemma2.py (+1, -0)
  python/sglang/srt/models/gemma3_causal.py (+1, -0)
  python/sglang/srt/models/gpt2.py (+1, -0)
  python/sglang/srt/models/gpt_bigcode.py (+1, -0)
  python/sglang/srt/models/granite.py (+1, -0)
  python/sglang/srt/models/grok.py (+1, -0)
python/sglang/srt/configs/model_config.py

@@ -239,7 +239,7 @@ class ModelConfig:
         # check if is modelopt model -- modelopt doesn't have corresponding field
         # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
         # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
-        is_local = os.path.isdir(self.model_path)
+        is_local = os.path.exists(self.model_path)
         modelopt_quant_config = {"quant_method": "modelopt"}
         if not is_local:
             from huggingface_hub import HfApi
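Note: the one-line fix above widens the local check from `os.path.isdir` to `os.path.exists`, so local paths that are not directories are also treated as local before sglang falls back to querying the Hub for the standalone `hf_quant_config.json`. A minimal standalone sketch of that detection idea (the `has_modelopt_quant_config` helper is illustrative, not part of the commit; it only uses `os.path` and `HfApi.list_repo_files`):

import os

from huggingface_hub import HfApi


def has_modelopt_quant_config(model_path: str) -> bool:
    """Illustrative check: is a standalone `hf_quant_config.json` present?"""
    if os.path.exists(model_path):
        # Local path (directory or single file), mirroring the fixed check above.
        return os.path.isfile(os.path.join(model_path, "hf_quant_config.json"))
    # Otherwise treat `model_path` as a Hub repo id and look for the file remotely.
    return "hf_quant_config.json" in HfApi().list_repo_files(model_path)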
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA

@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
         causal = not layer.is_cross_attention

         # Check if we should use local attention

@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention

@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+
         if not self.use_mla:
             # Do multi-head attention

@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention

@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
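Note: the backend now keeps both the torch dtype (`kv_cache_dtype`) and the raw server-arg string (`kv_cache_dtype_str`); when the latter is not "auto", the per-tensor `k_scale`/`v_scale` are broadcast to the `(batch_size, tp_k_head_num)` shape that the kernel's `k_descale`/`v_descale` arguments expect. A small self-contained sketch of that broadcast (scale values and shapes are made up for illustration; the attention call itself is omitted):

import torch

# Per-tensor scales as loaded from a checkpoint (0-dim tensors).
k_scale = torch.tensor(0.023, dtype=torch.float32)
v_scale = torch.tensor(0.041, dtype=torch.float32)

batch_size, tp_k_head_num = 4, 8
kv_cache_dtype_str = "fp8_e4m3"  # anything other than "auto" enables descaling

k_descale = v_descale = None
if kv_cache_dtype_str != "auto":
    descale_shape = (batch_size, tp_k_head_num)
    # expand() returns a broadcast view of the scalar; no data is copied.
    k_descale = k_scale.expand(descale_shape)
    v_descale = v_scale.expand(descale_shape)

print(k_descale.shape if k_descale is not None else None)  # torch.Size([4, 8])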
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None

@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]

@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

             o = prefill_wrapper_paged.forward(

@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(

@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]

@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
         assert v is not None
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                layer, cache_loc, k, v, k_scale, v_scale
             )

         o = decode_wrapper.forward(

@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
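Note: the FlashInfer path now forwards the plain-float copies (`k_scale_float`/`v_scale_float`) to `set_kv_buffer` and the wrapper `forward` calls, and only when a non-"auto" KV cache dtype was requested. A tiny sketch of that gating, with illustrative values:

from typing import Optional, Tuple


def resolve_kernel_scales(
    kv_cache_dtype_str: str,
    k_scale_float: Optional[float],
    v_scale_float: Optional[float],
) -> Tuple[Optional[float], Optional[float]]:
    """Only forward scales to the kernel when the KV cache is actually quantized."""
    if kv_cache_dtype_str != "auto":
        return k_scale_float, v_scale_float
    return None, None


# Quantized cache: floats are passed through; "auto": no scaling is applied.
print(resolve_kernel_scales("fp8_e4m3", 0.02, 0.04))  # (0.02, 0.04)
print(resolve_kernel_scales("auto", 0.02, 0.04))      # (None, None)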
python/sglang/srt/layers/quantization/kv_cache.py

@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()

@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache

@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )

     @classmethod
     def is_fp8_fnuz(cls) -> bool:

@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-
-            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
-                raise ValueError(
-                    "Only support per-tensor scaling factor " "for fp8 KV cache"
-                )
-
-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            layer._k_scale_float = k_scale
-            layer._v_scale_float = v_scale
-            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-                logger.warning(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint."
-                )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2

+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )

+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
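Note: the rewritten `process_weights_after_loading` keeps the same three-way scale resolution as before, but writes the result back into the layer's own `k_scale`/`v_scale` parameters and the `k_scale_float`/`v_scale_float` copies read by the attention backends, instead of the removed `_k_scale`/`_v_scale` buffers. A pure-function sketch of the decision tree (the HIP fp8-fnuz doubling is omitted; the function and test values are illustrative):

from typing import Tuple


def resolve_kv_scales(k_scale: float, v_scale: float) -> Tuple[float, float]:
    """Mirror of the scale-resolution logic: -1.0 marks 'not found in checkpoint'."""
    if k_scale > 0.0 and v_scale > 0.0:
        # Separate per-tensor scales were loaded; use them as-is.
        return k_scale, v_scale
    if k_scale < 0.0 and v_scale < 0.0:
        # Nothing was loaded; fall back to 1.0 (no rescaling).
        return 1.0, 1.0
    # A single kv_scale was remapped to k_scale during loading; duplicate it.
    assert k_scale > 0.0
    shared = max(k_scale, v_scale)
    return shared, shared


assert resolve_kv_scales(0.02, 0.04) == (0.02, 0.04)
assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
assert resolve_kv_scales(0.03, -1.0) == (0.03, 0.03)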
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (

@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention

 # Initialize logger for the module
 logger = logging.getLogger(__name__)

@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

-    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."

@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )

         if "FP8" not in quant_method:
             raise ValueError(

@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
                 "Check the `hf_quant_config.json` file for your model's configuration."
             )

-        return cls(is_checkpoint_fp8_serialized=True)
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer, AttentionBackend):
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
         return None
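Note: `from_config` now also reads `kv_cache_quant_algo` and `exclude_modules` from the checkpoint's `quantization` block, and `get_quant_method` attaches the KV-cache method to `RadixAttention` layers rather than to an attention backend. A sketch of the corresponding `hf_quant_config.json` contents, expressed as a Python dict (the values are illustrative; only the three keys read above are assumed):

# Illustrative `hf_quant_config.json` contents; only "quant_algo",
# "kv_cache_quant_algo", and "exclude_modules" are the keys the new
# ModelOptFp8Config.from_config reads.
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

quant = hf_quant_config["quantization"]
config_kwargs = dict(
    is_checkpoint_fp8_serialized=True,
    kv_cache_quant_method=quant.get("kv_cache_quant_algo"),
    exclude_modules=quant.get("exclude_modules"),
)
print(config_kwargs)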
python/sglang/srt/layers/radix_attention.py

@@ -13,8 +13,12 @@
 # ==============================================================================
 """Radix attention."""

+from typing import Optional
+
 from torch import nn

+from sglang.srt.layers.linear import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_irope: bool = False,
     ):

@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.use_irope = use_irope
         self.k_scale = None
         self.v_scale = None
+        self.use_irope = use_irope
+        self.k_scale_float = None
+        self.v_scale_float = None
+        self.quant_method = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if self.quant_method is not None:
+            self.quant_method.create_weights(self)

     def forward(
         self,
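Note: with `quant_config` accepted by `RadixAttention`, the constructor asks the config for a per-layer quant method and lets it attach the scale parameters via `create_weights`, which the checkpoint loader can then overwrite. A toy mock of that hook order (these classes are stand-ins for illustration, not the sglang ones):

import torch


class ToyKVCacheMethod:
    """Stand-in for a KV-cache quant method: attaches scale params to the layer."""

    def create_weights(self, layer: torch.nn.Module) -> None:
        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)


class ToyQuantConfig:
    """Stand-in for a QuantizationConfig exposing a get_quant_method hook."""

    def get_quant_method(self, layer, prefix=""):
        return ToyKVCacheMethod()


class ToyAttention(torch.nn.Module):
    def __init__(self, quant_config=None, prefix=""):
        super().__init__()
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if self.quant_method is not None:
            self.quant_method.create_weights(self)


attn = ToyAttention(quant_config=ToyQuantConfig(), prefix="model.layers.0.attn")
print(attn.k_scale, attn.v_scale)  # both initialized to -1.0 until weights load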
python/sglang/srt/models/baichuan.py

@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
         else:

@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
python/sglang/srt/models/chatglm.py

@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/commandr.py

@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
         if self.use_qk_norm:
python/sglang/srt/models/dbrx.py

@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/deepseek.py

@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/deepseek_v2.py

@@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )

@@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mha", prefix),
         )
python/sglang/srt/models/exaone.py

@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )

     def forward(
python/sglang/srt/models/gemma.py

@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/gemma2.py

@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/gemma3_causal.py

@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/gpt2.py

@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
             scaling=self.scale,
             num_kv_heads=total_num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )

     def forward(
python/sglang/srt/models/gpt_bigcode.py

@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/granite.py

@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
python/sglang/srt/models/grok.py

@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=logit_cap,
+            quant_config=quant_config,
         )

     def forward(