change / sglang · Commits

Commit a1175a4e, authored Nov 22, 2025 by maxiao1

    Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5

Parents: 0c006b88, 31653dd9
Changes: 62 files
Showing 20 changed files with 1825 additions and 131 deletions (+1825, -131).
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py  (+2, -1)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (+73, -3)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py  (+129, -74)
python/sglang/srt/layers/moe/topk.py  (+51, -0)
python/sglang/srt/layers/quantization/__init__.py  (+4, -1)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  (+45, -0)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_marlin.py  (+92, -0)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py  (+279, -0)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py  (+6, -3)
python/sglang/srt/layers/quantization/slimquant_w4a8.py  (+419, -0)
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py  (+402, -0)
python/sglang/srt/layers/quantization/w4a8_utils.py  (+92, -0)
python/sglang/srt/layers/quantization/w8a8_int8.py  (+5, -2)
python/sglang/srt/layers/rotary_embedding.py  (+124, -1)
python/sglang/srt/managers/schedule_batch.py  (+1, -0)
python/sglang/srt/managers/scheduler.py  (+1, -1)
python/sglang/srt/mem_cache/allocator.py  (+42, -18)
python/sglang/srt/mem_cache/common.py  (+14, -8)
python/sglang/srt/mem_cache/memory_pool.py  (+9, -0)
python/sglang/srt/model_executor/forward_batch_info.py  (+35, -19)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py

@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.int8_kernel import (
     per_token_group_quant_int8,
-    per_token_quant_int8,
+    # per_token_quant_int8,
     sglang_per_token_group_quant_int8,
 )
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -45,6 +45,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.environ import envs
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     cpu_has_amx_support,

@@ -58,6 +59,12 @@ from sglang.srt.utils import (
 if is_flashinfer_available():
     from flashinfer import RoutingMethodType, fp4_quantize

+_is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
+
 # Try to import FP4 TRTLLM function if flashinfer is available
 trtllm_fp4_block_scale_moe = None
 if get_moe_runner_backend().is_flashinfer_trtllm():

@@ -100,6 +107,49 @@ class FusedMoeWeightScaleSupported(Enum):
     GROUP = "group"
     BLOCK = "block"

+
+def determine_expert_map(
+    ep_size: int, ep_rank: int, global_num_experts: int
+) -> tuple[int, Optional[torch.Tensor]]:
+    """
+    Calculates how many experts should be assigned to each rank for EP and
+    creates a mapping from global to local expert index. Experts are
+    distributed evenly across ranks. Any remaining are assigned to the
+    last rank.
+
+    Args:
+        ep_size (int): The size of the expert parallel group
+        global_num_experts (int): The total number of experts in the model.
+
+    Returns:
+        tuple[int, Optional[torch.Tensor]]: A tuple containing:
+            - local_num_experts (int): The number of experts assigned
+              to the current rank.
+            - expert_map (Optional[torch.Tensor]): A tensor of shape
+              (global_num_experts,) mapping from global to local index.
+              Contains -1 for experts not assigned to the current rank.
+              Returns None if ep_size is 1.
+    """
+    assert ep_size > 0
+    if ep_size == 1:
+        return (global_num_experts, None)
+
+    local_num_experts = global_num_experts // ep_size
+
+    # Create a tensor of size num_experts filled with -1
+    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
+    # Create an expert map for the local experts
+    if ep_rank < (ep_size - 1):
+        # Each non-last rank gets local_num_experts experts.
+        expert_map[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    else:
+        # All remaining experts are assigned to the last rank.
+        local_num_experts = global_num_experts - ep_rank * local_num_experts
+        expert_map[-local_num_experts:] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    return (local_num_experts, expert_map)
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.

@@ -167,12 +217,20 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_size = get_moe_tensor_parallel_world_size()
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
-        self.num_local_experts = num_experts // self.moe_ep_size
+        # self.num_local_experts = num_experts // self.moe_ep_size
+        if self.moe_ep_size != 0:
+            self.num_local_experts, self.expert_map = determine_expert_map(
+                ep_size=self.moe_ep_size,
+                ep_rank=self.moe_ep_rank,
+                global_num_experts=num_experts)
+        else:
+            self.local_num_experts, self.expert_map = (self.global_num_experts, None)
         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
         self.use_presharded_weights = use_presharded_weights
+        # self.global_num_experts = self.num_experts

         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()

@@ -829,9 +887,21 @@ class FusedMoE(torch.nn.Module):
                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
             )

-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput = None,
+                shared_output: torch.Tensor = None, **kwargs):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None

+        if _user_lightop_moe_sum_mul_add:
+            final_hidden_states = self.quant_method.apply_with_shared_output(
+                layer=self,
+                x=hidden_states,
+                activation=getattr(self, 'moe_runner_config', None)
+                and self.moe_runner_config.activation or "silu",
+                shared_output=shared_output,
+                topk_output=topk_output,
+            )
+            if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
+                final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+            return final_hidden_states
+
         dispatch_output = self.dispatcher.dispatch(
             hidden_states=hidden_states, topk_output=topk_output
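Editor's note (not part of the commit): a minimal usage sketch of the new determine_expert_map helper from the hunk above, assuming it is importable from this module. The numbers are illustrative only; they show how the remainder of an uneven split lands on the last EP rank.

    import torch  # noqa: F401  (determine_expert_map returns a torch tensor)

    # Illustrative only: 10 global experts over an EP group of size 4.
    # Ranks 0-2 each own 10 // 4 = 2 experts; the last rank takes the remaining 4.
    local_num_experts, expert_map = determine_expert_map(
        ep_size=4, ep_rank=3, global_num_experts=10
    )
    assert local_num_experts == 4
    assert expert_map.tolist() == [-1, -1, -1, -1, -1, -1, 0, 1, 2, 3]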
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -4,7 +4,7 @@ import logging
 from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

+from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.dp_attention import get_is_extend_in_batch

@@ -55,6 +55,10 @@ import torch.distributed as dist
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()

+use_groupgemm = get_bool_env_var("SGLANG_GROUPGEMM", default="true")
+
 logger = logging.getLogger(__name__)

@@ -308,7 +312,7 @@ class _DeepEPDispatcherImplBase:
         self.params_bytes = 2

         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
-            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 64
         )
         # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
         # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it

@@ -357,18 +361,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     ):
         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         topk_ids = topk_ids.to(torch.int64)
-        if (
-            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            and not get_moe_runner_backend().is_cutlass()
-        ):
-            # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(
-                hidden_states,
-                128,
-                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            )
+        # if (
+        #     deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     and not get_moe_runner_backend().is_cutlass()
+        # ):
+        #     # TODO hard code 128 block quant,use fp8 communication
+        #     hidden_states = sglang_per_token_group_quant_fp8(
+        #         hidden_states,
+        #         128,
+        #         column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #     )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_ids, topk_weights, previous_event

@@ -380,7 +384,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             num_recv_tokens_per_expert,
             event,
         ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
+        # event.current_stream_wait() if self.async_finish else ()
         if isinstance(hidden_states, tuple):
             hidden_states, hidden_states_scale = hidden_states

@@ -435,19 +439,25 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
             num_tokens_per_expert=num_tokens_per_expert,
-            previous_event=previous_event,
-            async_finish=self.async_finish,
-            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
+            previous_event=None,
+            async_finish=False,
+            allocate_on_comm_stream=False,
+            expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )

+        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
+            self.rank_expert_offset = get_moe_expert_parallel_rank() * (
+                self.num_experts // get_moe_expert_parallel_world_size())
+            recv_topk_ids = torch.where(
+                recv_topk_ids == -1,
+                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+                recv_topk_ids + self.rank_expert_offset)
+        else:
             get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
                 num_recv_tokens_per_expert,
                 num_tokens_per_rank=num_tokens_per_rank,
                 num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                 num_tokens_per_expert=num_tokens_per_expert,
             )

         return (
             recv_x,
             recv_topk_ids,

@@ -464,17 +474,38 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         overlap_args: Optional[CombineOverlapArgs] = None,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
-            output = hidden_states
-        else:
-            raise NotImplementedError()
+        # if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
+        output = hidden_states
+        # else:
+        # triton runner was supported but it's temporarily disabled
+        #     if hidden_states.shape[0] > 0:
+        #         num_tokens = self.src2dst.shape[0] // self.router_topk
+        #         output = torch.empty(
+        #             (num_tokens, hidden_states.shape[1]),
+        #             device=hidden_states.device,
+        #             dtype=hidden_states.dtype,
+        #         )
+        #         deepep_post_reorder_triton_kernel[(num_tokens,)](
+        #             hidden_states,
+        #             output,
+        #             self.src2dst,
+        #             topk_idx,
+        #             topk_weights,
+        #             self.router_topk,
+        #             hidden_states.shape[1],
+        #             BLOCK_SIZE=512,
+        #         )
+        #     else:
+        #         output = torch.zeros(
+        #             (0, hidden_states.shape[1]),
+        #             device=hidden_states.device,
+        #             dtype=hidden_states.dtype,
+        #         )
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event

     def combine_b(self, output, previous_event):
         hidden_states, event = self._combine_core(output, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
+        # event.current_stream_wait() if self.async_finish else ()
         self.handle = None
         self.src2dst = None
         return hidden_states

@@ -484,9 +515,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         combined_x, _, event = buffer.combine(
             x,
             self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
+            async_finish=False,
+            previous_event=None,
+            allocate_on_comm_stream=False,
             config=DeepEPConfig.get_instance().normal_combine_config,
         )
         return combined_x, event

@@ -515,7 +546,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        self.return_recv_hook = return_recv_hook
+        self.return_recv_hook = False
         self.device_module = torch.get_device_module()
         self.quant_config = {}

@@ -589,13 +620,26 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         use_fp8 = True

         buffer = self._get_buffer()
+        if use_groupgemm:
+            packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
+                buffer.low_latency_dispatch(
+                    hidden_states,
+                    topk_ids,
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.num_experts,
+                    use_fp8=False,
+                    async_finish=not self.return_recv_hook,
+                    return_recv_hook=self.return_recv_hook,
+                )
+            )
+        else:
             packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
                 buffer.low_latency_dispatch(
                     hidden_states,
                     topk_ids,
                     self.num_max_dispatch_tokens_per_rank,
                     self.num_experts,
-                    use_fp8=use_fp8,
+                    use_fp8=False,
                     **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
                     **(
                         dict(x_global_scale=input_global_scale)

@@ -653,11 +697,23 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             ctx = torch.cuda.stream(overlap_args.stream)
         with ctx:
+            if use_groupgemm:
+                combined_hidden_states, event, hook = buffer.low_latency_combine(
+                    x=hidden_states,
+                    topk_idx=topk_ids,
+                    topk_weights=topk_weights,
+                    handle=self.handle,
+                    zero_copy=False,
+                    async_finish=not self.return_recv_hook,
+                    return_recv_hook=self.return_recv_hook,
+                )
+            else:
                 combined_hidden_states, event, hook = buffer.low_latency_combine(
                     x=hidden_states,
                     topk_idx=topk_ids,
                     topk_weights=topk_weights,
                     handle=self.handle,
+                    zero_copy=False,
                     async_finish=not self.return_recv_hook,
                     return_recv_hook=self.return_recv_hook,
                     **(

@@ -670,7 +726,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                     else {}
                 ),
             )
         self.packed_recv_count = self.handle = None
         return combined_hidden_states, event, hook
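Editor's note (not part of the commit): a hedged sketch of what the new slimquant_w4a8_marlin branch in _dispatch_core does to the received top-k ids. Local ids coming back from DeepEP are shifted into the global expert numbering for this rank, and the -1 padding slots are redirected to an expert index this rank does not own. The tensor values below are illustrative.

    import torch

    # Illustrative only: 256 experts total, EP world size 4, this is EP rank 1.
    num_experts, ep_rank, ep_world = 256, 1, 4
    rank_expert_offset = ep_rank * (num_experts // ep_world)   # 64

    recv_topk_ids = torch.tensor([0, 5, -1, 63])                # local ids, -1 = padding
    remapped = torch.where(
        recv_topk_ids == -1,
        num_experts - 1 if rank_expert_offset == 0 else 0,      # park padding on a foreign expert
        recv_topk_ids + rank_expert_offset,                     # shift back to global numbering
    )
    print(remapped.tolist())  # [64, 69, 0, 127]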
python/sglang/srt/layers/moe/topk.py

@@ -28,6 +28,8 @@ from typing import (
     runtime_checkable,
 )

+from numpy import dtype
+
 import torch
 import torch.nn.functional as F

@@ -45,6 +47,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
 from sglang.srt.layers.dp_attention import is_allocation_symmetric
 from sglang.srt.layers.moe import get_moe_runner_backend
 from sglang.srt.utils import (
+    direct_register_custom_op,
     cpu_has_amx_support,
     get_bool_env_var,
     get_compiler_backend,

@@ -70,6 +73,7 @@ _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")

 if _is_cuda:
     from sgl_kernel import moe_fused_gate

@@ -81,9 +85,44 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+
+if _use_lightop:
+    from lightop import op as op
+
 if _is_npu:
     import torch_npu

+
+# ------- custom op for moe_fused_gate
+def moe_fused_gate_dcu(
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    num_expert_group: int,
+    topk_group: int,
+    topk: int,
+    num_fused_shared_experts: int,
+    routed_scaling_factor: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    topk_weights, topk_ids = op.moe_fused_gate(
+        gating_output.to(dtype=torch.float32),  # or bfloat16
+        correction_bias,
+        num_expert_group,
+        topk_group,
+        topk,
+        num_fused_shared_experts,  # 0 in vllm
+        routed_scaling_factor,
+    )
+    return topk_weights, topk_ids
+
+
+def moe_fused_gate_fake(
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    num_expert_group: int,
+    topk_group: int,
+    topk: int,
+    num_fused_shared_experts: int,
+    routed_scaling_factor: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(
+        (gating_output.size(0), topk), dtype=gating_output.dtype, device=gating_output.device
+    ), \
+        torch.empty(
+            (gating_output.size(0), topk), dtype=gating_output.dtype, device=gating_output.device
+        )
+
+
+direct_register_custom_op(
+    op_name="moe_fused_gate_dcu",
+    op_func=moe_fused_gate_dcu,
+    mutates_args=[],
+    fake_impl=moe_fused_gate_fake,
+)
+# -------
+
 # -------------------------------- TopKConfig ---------------------------------------

@@ -759,6 +798,18 @@ def biased_grouped_topk_gpu(
             routed_scaling_factor,
         )
         return topk_weights, topk_ids
+    elif _use_lightop:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
+        topk_weights, topk_ids = torch.ops.sglang.moe_fused_gate_dcu(
+            gating_output.to(dtype=torch.float32),  # or bfloat16
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+            0,  # 0 in vllm
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         return biased_grouped_topk_impl(
             hidden_states,
python/sglang/srt/layers/quantization/__init__.py

@@ -38,6 +38,8 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

 _is_mxfp_supported = mxfp_supported()

@@ -65,7 +67,8 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "auto-round": AutoRoundConfig,
+    "slimquant_w4a8_marlin": SlimQuantW4A8Int8MarlinConfig,
+    "slimquant_marlin": SlimQuantCompressedTensorsMarlinConfig,
 }
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

 logger = logging.getLogger(__name__)

@@ -639,3 +640,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
         return scheme.apply_weights(layer, x, bias=bias)
+
+
+class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from compressed-tensors
+    checkpoints.
+    """
+
+    def __init__(self, quant_config: CompressedTensorsConfig):
+        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
+        super().__init__(quant_config)
+
+    @staticmethod
+    def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
+        """
+        Validator for the kv cache scheme. Useful for controlling the
+        kv cache quantization schemes, that are being supported in vLLM
+        :param kv_cache_scheme: the compressed-tensors kv cache scheme
+        """
+        if kv_cache_scheme is None:
+            return
+
+        type_ = kv_cache_scheme.get("type")
+        num_bits = kv_cache_scheme.get("num_bits")
+
+        if type_ != "float" and num_bits != 8:
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                "num_bits=8, type=float, however "
+                f"received num_bits={num_bits}, type={type_}")
+
+        strategy = kv_cache_scheme.get("strategy")
+        if strategy != "tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for compressed-tensors KV cache. "
+                f"Expected strategy: tensor, found strategy: {strategy}")
+
+        is_symmetric = kv_cache_scheme.get("symmetric")
+        if not is_symmetric:
+            raise NotImplementedError(
+                "Only support symmetric scaling factor "
+                "for compressed-tensors KV cache. "
+                f"However found symmetric: {is_symmetric}")
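Editor's note (not part of the commit): for reference, an illustrative kv_cache_scheme dict that the validator above accepts — 8-bit float, per-tensor, symmetric. The dict values are examples, not taken from a real checkpoint.

    # Passes validate_kv_cache_scheme without raising:
    kv_cache_scheme = {
        "type": "float",       # anything other than float / 8-bit is rejected
        "num_bits": 8,
        "strategy": "tensor",  # only per-tensor scales are supported
        "symmetric": True,     # asymmetric scales are rejected
    }
    CompressedTensorsKVCacheMethod.validate_kv_cache_scheme(kv_cache_scheme)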
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_marlin.py (new file, 0 → 100644)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs
import logging

from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
from sglang.srt.layers.quantization.base_config import (
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
    CompressedTensorsLinearMethod,
    CompressedTensorsKVCacheMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import (
    CompressedTensorsMarlinMoEMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (
    should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
import os

# if TYPE_CHECKING:
#     from vllm.model_executor.models.utils import WeightsMapper

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]


class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
        packed_modules_mapping: Optional[dict[str, list[str]]] = None,
    ):
        super().__init__(
            target_scheme_map,
            ignore,
            quant_format,
            sparsity_scheme_map,
            sparsity_ignore_list,
            kv_cache_scheme,
            config,
            packed_modules_mapping,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        if hf_quant_cfg.get("quant_method") == "compressed-tensors" \
                and user_quant == "slimquant_marlin":
            return cls.get_name()
        return None

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_marlin"

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

        # Avoid circular import
        from sglang.srt.layers.radix_attention import RadixAttention

        # Check if the layer is skipped for quantization.
        if should_ignore_layer(
            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        if isinstance(layer, RadixAttention):
            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
        return None
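Editor's note (not part of the commit): a hedged illustration of the override hook above. When the checkpoint's HF quantization config says compressed-tensors but the user explicitly asks for slimquant_marlin, this config class takes over; otherwise the checkpoint's own method stays in effect. The hf_quant_cfg dict below is a made-up minimal example.

    hf_quant_cfg = {"quant_method": "compressed-tensors"}

    assert SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
        hf_quant_cfg, user_quant="slimquant_marlin"
    ) == "slimquant_marlin"

    # Any other user_quant value leaves the checkpoint's method untouched.
    assert SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
        hf_quant_cfg, user_quant=None
    ) is None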
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py (new file, 0 → 100644)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import enum
from enum import Enum
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
import logging
from torch.nn.parameter import Parameter

from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import set_weight_attrs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.utils import get_moe_a2a_backend

try:
    from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
    print("INFO: Please install lmslim if you want to run inference with a quantized MoE model.\n")

logger = logging.getLogger(__name__)

__all__ = [
    "CompressedTensorsW8A8Int8MarlinMoEMethod",
]


def get_w8a8_int8_marlin_weights(weight, k_tile=64):
    # 7168, 512
    weight = weight.T
    size_k, size_n = weight.shape
    assert size_k // k_tile
    weight = weight.reshape(size_k // k_tile, k_tile, size_n)
    weight = weight.transpose(1, 2)
    weight = weight.reshape(size_k // k_tile, size_n * k_tile)
    return weight


def w8a8_nt_kpack2_marlin_weight(
    w8a8_w,  # [size_n, size_k // 2]
    k_tile=16,
    n_tile=16,
):
    assert w8a8_w.dtype == torch.int8, "w8a8_w must be of type int8"
    size_n, size_k = w8a8_w.shape
    assert size_n % k_tile == 0 and size_k % n_tile == 0, \
        "k_tile / n_tile must evenly divide the corresponding dimension"
    q = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
    q = q.permute((0, 2, 1, 3)).contiguous()
    q = q.reshape((size_n // k_tile, size_k * k_tile))
    return q


class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
        quant_config: "SlimQuantCompressedTensorsMarlinConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
    ) -> "CompressedTensorsMarlinMoEMethod":
        # are supported + check if the layer is being ignored.
        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")

        if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
        else:
            raise RuntimeError(
                f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}"
            )


class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):

    def __init__(
        self,
        quant_config: "CompressedTensorsMarlinConfig",  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get("input_activations")
        self.use_deepep = get_moe_a2a_backend().is_deepep()

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN
        )
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}"
            )

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales."
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        params_dtype = torch.int8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w1_marlin_list = []
        for ii in range(layer.w13_weight.shape[0]):
            if not self.use_deepep:
                w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
            else:
                w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
            w1_marlin_list.append(w1_marlin_in)
        w1_marlin = torch.stack(w1_marlin_list, dim=0)

        w2_marlin_list = []
        for ii in range(layer.w2_weight.shape[0]):
            if not self.use_deepep:
                w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
            else:
                w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
            w2_marlin_list.append(w2_marlin_in)
        w2_marlin = torch.stack(w2_marlin_list, dim=0)

        layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
        layer.w2_weight = Parameter(w2_marlin, requires_grad=False)

    def create_moe_runner(self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig):
        self.moe_runner_config = moe_runner_config

    # def apply(
    #     self,
    #     layer: torch.nn.Module,
    #     x: torch.Tensor,
    #     router_logits: torch.Tensor,
    #     top_k: int,
    #     renormalize: bool,
    #     use_grouped_topk: bool = False,
    #     topk_group: Optional[int] = None,
    #     num_expert_group: Optional[int] = None,
    #     global_num_experts: int = -1,
    #     expert_map: Optional[torch.Tensor] = None,
    #     custom_routing_function: Optional[Callable] = None,
    #     scoring_func: str = "softmax",
    #     e_score_correction_bias: Optional[torch.Tensor] = None,
    #     apply_router_weight_on_input: bool = False,
    #     activation: str = "silu",
    #     enable_eplb: bool = False,
    #     use_nn_moe: Optional[bool] = False,
    #     routed_scaling_factor: Optional[float] = None,
    #     use_fused_gate: Optional[bool] = False,
    #     expert_load_view: Optional[torch.Tensor] = None,
    #     logical_to_physical_map: Optional[torch.Tensor] = None,
    #     logical_replica_count: Optional[torch.Tensor] = None,
    #     shared_output: Optional[torch.Tensor] = None,
    # ) -> torch.Tensor:
    #     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    #     if enable_eplb:
    #         raise NotImplementedError(
    #             "EPLB not supported for "
    #             "`CompressedTensorsW8A8Int8MoEMethod` yet.")
    #     topk_weights, topk_ids = FusedMoE.select_experts(
    #         hidden_states=x,
    #         router_logits=router_logits,
    #         use_grouped_topk=use_grouped_topk,
    #         top_k=top_k,
    #         renormalize=renormalize,
    #         topk_group=topk_group,
    #         num_expert_group=num_expert_group,
    #         custom_routing_function=custom_routing_function,
    #         scoring_func=scoring_func,
    #         routed_scaling_factor=routed_scaling_factor,
    #         use_fused_gate=use_fused_gate,
    #         e_score_correction_bias=e_score_correction_bias)
    #     return fused_experts_impl_int8_marlin(
    #         hidden_states=x,
    #         w1=layer.w13_weight,
    #         w2=layer.w2_weight,
    #         topk_weights=topk_weights,
    #         topk_ids=topk_ids,
    #         inplace=True,
    #         activation=activation,
    #         apply_router_weight_on_input=apply_router_weight_on_input,
    #         use_int8_w8a8=True,
    #         per_channel_quant=True,
    #         global_num_experts=global_num_experts,
    #         expert_map=expert_map,
    #         w1_scale=layer.w13_weight_scale,
    #         w2_scale=layer.w2_weight_scale,
    #         a1_scale=layer.w13_input_scale,
    #         a2_scale=layer.w2_input_scale,
    #         use_nn_moe=False,
    #         shared_output=shared_output,
    #         routed_scaling_factor=routed_scaling_factor)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output,
    ):
        from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )

        output = fused_experts_impl_int8_marlin(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=layer.moe_runner_config.activation,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=layer.moe_runner_config.num_experts,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=False,
        )
        return StandardCombineInput(hidden_states=output)
(No newline at end of file)
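Editor's note (not part of the commit): a hedged shape sketch for the two weight-repacking helpers in this new file, assuming they are importable from the module above. The tensor sizes are arbitrary and chosen only to make the transformations visible.

    import torch

    # get_w8a8_int8_marlin_weights: an [N, K] int8 weight becomes [K // k_tile, N * k_tile].
    w = torch.randint(-128, 127, (512, 7168), dtype=torch.int8)   # [N=512, K=7168]
    print(get_w8a8_int8_marlin_weights(w, k_tile=64).shape)       # torch.Size([112, 32768])

    # w8a8_nt_kpack2_marlin_weight: a [size_n, size_k] int8 weight is regrouped into
    # 16x16 tiles and flattened to [size_n // 16, size_k * 16].
    q = torch.randint(-128, 127, (512, 3584), dtype=torch.int8)
    print(w8a8_nt_kpack2_marlin_weight(q).shape)                  # torch.Size([32, 57344])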
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

@@ -15,10 +15,11 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
 from sglang.srt.utils import is_cuda
+from lmslim import quant_ops

 _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm

@@ -162,12 +163,14 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         layer.register_parameter("input_zero_point", input_zero_point)

+    @torch._dynamo.disable()
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
     ) -> torch.Tensor:
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)

-        return int8_scaled_mm(
+        # TODO: fix with lmslim/lightop
+        return quant_ops.triton_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
python/sglang/srt/layers/quantization/slimquant_w4a8.py
0 → 100644
View file @
a1175a4e
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
sglang.srt.layers.linear
import
set_weight_attrs
from
sglang.srt.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
sglang.srt.layers.linear
import
LinearBase
from
sglang.srt.layers.quantization.base_config
import
LinearMethodBase
,
QuantizationConfig
,
QuantizeMethodBase
,
FusedMoEMethodBase
from
sglang.srt.layers.parameter
import
(
ChannelQuantScaleParameter
,
_ColumnvLLMParameter
,
RowvLLMParameter
,
)
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
sglang.srt
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
sglang.srt.layers.moe
import
MoeRunner
,
MoeRunnerBackend
,
MoeRunnerConfig
import
os
from
sglang.srt.utils
import
get_bool_env_var
_use_fused_rms_quant
=
get_bool_env_var
(
"SGLANG_USE_FUSED_RMS_QUANT"
)
_use_fused_silu_mul_quant
=
get_bool_env_var
(
"SGLANG_USE_FUSED_SILU_MUL_QUANT"
)
class
ModelWeightParameter
(
_ColumnvLLMParameter
,
RowvLLMParameter
):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
scales
=
scale_a
*
scale_b
.
T
gemmout
=
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
)
class
SlimQuantW4A8Int8Config
(
QuantizationConfig
):
"""Config class for W8A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def
__init__
(
self
):
pass
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
@
classmethod
def
get_name
(
self
)
->
str
:
return
"slimquant_w4a8"
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SlimQuantW4A8Int8Config"
:
return
cls
()
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SlimQuantW4A8Int8LinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quantization_config
:
SlimQuantW4A8Int8Config
):
self
.
quantization_config
=
quantization_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
if
self
.
w8a8_strategy
==
1
:
if
{
n
,
k
}
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
self
.
logical_widths
=
output_partition_sizes
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
@
torch
.
_dynamo
.
disable
()
# TODO: 性能优化需要lmslim/lightop配合
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
if
_use_fused_rms_quant
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
elif
_use_fused_silu_mul_quant
and
silu_quant_args
is
not
None
:
x_q
,
x_scale
=
silu_quant_args
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
self
.
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
layer
.
weight
.
shape
[
1
]
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
elif
m
<
200
:
#256
m_
=
160
elif
m
<
480
:
#512
m_
=
256
elif
m
<
960
:
#1024
m_
=
512
elif
m
<
2048
:
m_
=
1024
elif
m
<
4096
:
m_
=
2048
elif
m
<
6000
:
m_
=
4096
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return
ops
.
triton_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
elif
self
.
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
class
SlimQuantW4A8Int8MoEMethod
:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config
        self.tritonsingleton = W8a8GetCacheJSON()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size // 2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        E = layer.w13_weight.shape[0]
        N1 = layer.w13_weight.shape[1]
        N2 = layer.w2_weight.shape[1]
        K = N1 // 2
        if [E, N1, N2, K] not in self.tritonsingleton.moe_weight_shapes:
            self.tritonsingleton.moe_weight_shapes.append([E, N1, N2, K])
            TOPK = self.tritonsingleton.topk
            json_file = self.tritonsingleton.get_moeint8json_name(
                E, N1, N2, K, TOPK, use_int4_w4a8=True
            )
            configs_dict = self.tritonsingleton.get_moeint8_triton_cache(
                json_file, E, N1, N2, K, TOPK
            )
            # warmup
            if configs_dict:
                self.tritonsingleton.triton_moejson_dict.update(configs_dict)

        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
        **_,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
            )

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_int4_w4a8=True,
            per_channel_quant=True,
            activation=activation,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
        )
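For orientation, here is a minimal sketch of the tensor layout create_weights registers and how the per-channel scale broadcasts when dequantizing one expert. Sizes are illustrative, not from the commit; unpack_int8_to_int4 is the helper from w4a8_utils.py added later in this commit, and the int4 zero-point of 8 used below is an assumption, not necessarily the kernel's convention.

import torch
from sglang.srt.layers.quantization.w4a8_utils import unpack_int8_to_int4

E, I, H = 2, 128, 256                                   # illustrative expert/intermediate/hidden sizes
w13_weight = torch.randint(-128, 128, (E, 2 * I, H // 2), dtype=torch.int8)  # two int4 values packed per byte along K
w13_scale = torch.rand(E, 2 * I, 1, dtype=torch.float32)                     # one float32 scale per output channel

w13_nibbles = unpack_int8_to_int4(w13_weight[0])        # [2*I, H] nibble values in [0, 15]
w13_fp = (w13_nibbles.float() - 8.0) * w13_scale[0]     # assumed zero-point of 8; the real kernel's
                                                        # int4 convention may differ
print(w13_fp.shape)                                     # torch.Size([256, 256])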
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py  0 → 100644
This diff is collapsed.

python/sglang/srt/layers/quantization/w4a8_utils.py  0 → 100644

import torch
import numpy as np

try:
    from lightop import awq_marlin_repack_w4a8

    use_lightop = False  # NOTE: the lightop repack path stays disabled even when the import succeeds
except Exception:
    use_lightop = False


def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """Convert a [N, K//2] torch.int8 tensor into a [N, K] torch.int32 tensor.

    Each int8 packs two int4 values; each value is extracted into the low 4 bits
    of an int32, with the remaining bits set to 0.

    Args:
        tensor_int8 (torch.Tensor): Input tensor of shape [N, K//2], dtype torch.int8.

    Returns:
        torch.Tensor: Output tensor of shape [N, K], dtype torch.int32.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")
    N, K_half = tensor_int8.shape
    tensor_uint8 = tensor_int8.to(torch.uint8)
    # Note the naming: 'high4' holds the low nibble of each byte and 'low4' the high nibble.
    high4 = tensor_uint8 & 0x0F
    low4 = (tensor_uint8 >> 4) & 0x0F
    unpacked = torch.empty(
        (N, K_half * 2), dtype=torch.int32, device=tensor_int8.device
    )
    unpacked[:, 0::2] = low4.to(torch.int32)
    unpacked[:, 1::2] = high4.to(torch.int32)
    return unpacked


def get_weight_perms(interleave: bool = True):
    perm = []
    for i in range(64):
        for col in range(4):
            cur_col = (i % 16) * 4 + col
            for row in range(8):
                cur_row = (i // 16) * 8 + row
                cur_idx = cur_row * 64 + cur_col
                perm.append(cur_idx)
    perm = np.array(perm)
    if interleave:
        interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
        perm = perm.reshape((-1, 8))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_weights(q_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8):
    size_k, size_n = q_w.shape
    q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
    q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)

    orig_device = q_w.device
    q_w = q_w.contiguous().to(torch.int32)
    M, N = q_w.shape
    assert (
        N % pack_factor == 0
    ), f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
    q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
    for i in range(pack_factor):
        q_packed += q_w[:, i::pack_factor] << (4 * i)
    return q_packed


def w4a8_2_marlin_weight(w4a8_w):
    full_w4a8_w = unpack_int8_to_int4(w4a8_w)
    full_w4a8_w = full_w4a8_w.T
    weight_perm = get_weight_perms()
    marlin_q_w = marlin_weights(
        full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8
    )
    return marlin_q_w


def w4a8_weight_repack_impl(input):
    if use_lightop:
        size_batch = input.shape[0]
        size_n = input.shape[1]
        size_k = input.shape[2] * 2
        output = torch.zeros(
            (size_batch, size_k // 32, size_n * 4),
            device=input.device,
            dtype=torch.int32,
        )
        awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
    else:
        w_marlin_list = []
        for e in range(input.shape[0]):
            w_marlin_in = w4a8_2_marlin_weight(input[e])
            w_marlin_list.append(w_marlin_in)
        output = torch.stack(w_marlin_list, dim=0)
    return output
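A quick round-trip check of the packing convention implied by unpack_int8_to_int4 (element 2*i of a row sits in the high nibble of byte i, element 2*i+1 in the low nibble), followed by a shape check of the Marlin repack fallback path. pack_int4_rows is a hypothetical helper written here for illustration; the other two functions are the ones defined above.

import torch
from sglang.srt.layers.quantization.w4a8_utils import (
    unpack_int8_to_int4,
    w4a8_weight_repack_impl,
)

def pack_int4_rows(vals: torch.Tensor) -> torch.Tensor:
    # vals: [N, K] with entries in [0, 15]
    even = vals[:, 0::2]                     # goes to the high nibble of each byte
    odd = vals[:, 1::2]                      # goes to the low nibble of each byte
    packed = (even << 4) | (odd & 0xF)
    return packed.to(torch.uint8).view(torch.int8)   # reinterpret the bytes as int8 storage

vals = torch.randint(0, 16, (64, 64), dtype=torch.int32)
packed = pack_int4_rows(vals)                         # [64, 32] int8
assert torch.equal(unpack_int8_to_int4(packed), vals)

w = packed.unsqueeze(0)                               # [E=1, N=64, K//2=32]
marlin_w = w4a8_weight_repack_impl(w)                 # CPU fallback path -> [1, K//32, N*4] = [1, 2, 256] int32
print(marlin_w.shape)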
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
         CombineInput,
         StandardDispatchOutput,
     )
+from lmslim import quant_ops
+
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
         output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-        output = int8_scaled_mm(
+        output = quant_ops.triton_scaled_mm(
             x_q_2d,
             layer.weight,
             x_scale_2d,
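The last hunk swaps the int8_scaled_mm call for lmslim's quant_ops.triton_scaled_mm. As a rough reference for the math such a per-token / per-channel scaled int8 GEMM performs (a sketch of the semantics only, not the lmslim kernel's actual signature, which is an assumption here):

import torch

def reference_scaled_mm(x_q, w_q, x_scale, w_scale, out_dtype=torch.bfloat16):
    """x_q: [M, K] int8, w_q: [K, N] int8, x_scale: [M, 1], w_scale: [1, N]."""
    # A fused kernel accumulates the int8 products in int32 and applies both
    # scales on the way out; float accumulation is used here only for brevity.
    acc = x_q.float() @ w_q.float()
    return (acc * x_scale * w_scale).to(out_dtype)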
python/sglang/srt/layers/rotary_embedding.py

@@ -23,6 +23,8 @@ from sglang.srt.utils import (
     is_xpu,
 )
+from sglang.srt.utils import direct_register_custom_op
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
@@ -30,6 +32,7 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
+_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")

 if _is_cuda:
     from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -58,6 +61,34 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)

+# for dcu
+@triton.jit
+def deepseek_scaling_rotary_emb_kernel_gptj(
+    cos_sin,
+    q,
+    stride1: int,
+    stride2: int,
+    stride_cs: int,
+    dim1: int,
+    dim2: int,
+    dim3: int,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid0 = tl.program_id(0)
+    pid1 = tl.program_id(1)
+    pid2 = tl.program_id(2)
+    offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
+    offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
+    offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
+    mask = offsets_cs < dim3
+    mask2 = offsets_q < dim3 * 2
+    v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
+    v_cos2 = tl.interleave(v_cos, v_cos)
+    v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
+    v_sin2 = tl.interleave(v_sin, v_sin)
+    x12 = tl.load(q + offsets, mask=mask2)
+    x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
+    # we are both reading and writing 'q'; make sure all warps are in sync
+    tl.debug_barrier()
+    x12_ = tl.ravel(tl.join(-x2, x1))
+    x12 = x12 * v_cos2 + x12_ * v_sin2
+    tl.store(q + offsets, x12, mask=mask2)
+
 def _apply_rotary_emb(
     x: torch.Tensor,
@@ -765,6 +796,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         # Re-dispatch
         if _is_hip:
-            self._forward_method = self.forward_native
+            if _use_lightop:
+                self._forward_method = self.forward_dcu
+            else:
+                self._forward_method = self.forward_native

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -808,6 +842,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

+    def rotary_embedding_deepseek_fuse(
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        head_size: int,
+        cos_sin_cache: torch.Tensor,
+        is_neox_style: bool,
+    ) -> None:
+        from lightop import op
+
+        op.rotary_embedding_deepseek_fuse(
+            positions, query, key, head_size, cos_sin_cache, is_neox_style
+        )
+
+    def rotary_embedding_deepseek_fuse_fake(
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        head_size: int,
+        cos_sin_cache: torch.Tensor,
+        is_neox_style: bool,
+    ) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="rotary_embedding_deepseek_fuse",
+        op_func=rotary_embedding_deepseek_fuse,
+        mutates_args=["query", "key"],
+        fake_impl=rotary_embedding_deepseek_fuse_fake,
+    )
+
     def forward_native(
         self,
         positions: torch.Tensor,
@@ -849,6 +901,77 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)

+    def forward_dcu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert key is not None
+        if self.cos_sin_cache.device != positions.device:
+            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        if query.device.type == 'cuda' and not self.is_neox_style:
+            # not self.reference ?
+            assert len(query.shape) == 3
+
+            def call(q):
+                BLOCK_SIZE = 64
+                grid = (
+                    q.shape[-3],
+                    q.shape[-2],
+                    triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
+                )
+                deepseek_scaling_rotary_emb_kernel_gptj[grid](
+                    cos_sin,
+                    q,
+                    stride1=q.stride()[-3],
+                    stride2=q.stride()[-2],
+                    stride_cs=cos_sin.stride()[-2],
+                    dim1=q.shape[0],
+                    dim2=q.shape[1],
+                    dim3=self.rotary_dim // 2,
+                    BLOCK_SIZE=BLOCK_SIZE,
+                    num_warps=1,
+                )
+
+            if _use_lightop:
+                torch.ops.sglang.rotary_embedding_deepseek_fuse(
+                    positions,
+                    query,
+                    key,
+                    self.head_size,
+                    self.cos_sin_cache,
+                    self.is_neox_style,
+                )
+            else:
+                call(query)
+                call(key)
+            return query, key
+        else:
+            query_rot = query[..., : self.rotary_dim]
+            key_rot = key[..., : self.rotary_dim]
+            if self.rotary_dim < self.head_size:
+                query_pass = query[..., self.rotary_dim :]
+                key_pass = key[..., self.rotary_dim :]
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            if self.is_neox_style:
+                # NOTE(woosuk): Here we assume that the positions tensor has the
+                # shape [batch_size, seq_len].
+                cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+                sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+            else:
+                cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+                sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+            query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+            key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+            if self.rotary_dim < self.head_size:
+                query = torch.cat((query_rot, query_pass), dim=-1)
+                key = torch.cat((key_rot, key_pass), dim=-1)
+            else:
+                query = query_rot
+                key = key_rot
+            return query, key
+
     def forward_npu(
         self,
         positions: torch.Tensor,
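For reference, the interleaved (GPT-J style) rotation that deepseek_scaling_rotary_emb_kernel_gptj applies in place to q can be written in plain PyTorch as below. This is a sketch of the math only; the real kernel additionally fuses the cos/sin gather and touches only the first rotary_dim elements of each head, and the function name here is illustrative.

import torch

def rotate_gptj_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """x: [..., rotary_dim] with pairs interleaved; cos, sin: [..., rotary_dim // 2]."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin   # matches x12 * v_cos2 + (-x2) * v_sin2 at even lanes
    out[..., 1::2] = x2 * cos + x1 * sin   # matches x12 * v_cos2 + ( x1) * v_sin2 at odd lanes
    return out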
python/sglang/srt/managers/schedule_batch.py

@@ -1246,6 +1246,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = seq_lens_tensor
         self.seq_lens_cpu = seq_lens_cpu
         self.extend_num_tokens = extend_num_tokens
+        self.loc_tensor = torch.tensor([-1], device=self.device)

         # Allocate memory
         out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
python/sglang/srt/managers/scheduler.py

@@ -2006,7 +2006,7 @@ class Scheduler(
             batch.spec_info = batch_result.next_draft_input
             batch.spec_info.future_indices = future_indices
+            batch.sampling_info.is_all_greedy = True  # nhb
             # batch.spec_info = EagleDraftInput(
             #     future_indices=future_indices,
             #     verify_done=batch_result.next_draft_input.verify_done,
python/sglang/srt/mem_cache/allocator.py
This diff is collapsed.

python/sglang/srt/mem_cache/common.py
This diff is collapsed.
python/sglang/srt/mem_cache/memory_pool.py

@@ -1404,6 +1404,15 @@ class MLATokenToKVPool(KVCache):
         return self.kv_buffer[layer_id - self.start_layer]

+    def get_key_buffer_DeepSeekV2(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        if self.store_dtype != self.dtype and self.dtype not in (
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+        ):
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer], self.dtype
+
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
python/sglang/srt/model_executor/forward_batch_info.py
This diff is collapsed.