sglang / Commits / Commit b4c41f72 (unverified)
Authored Jun 14, 2025 by fzyzcjy · Committed by GitHub Jun 13, 2025
Parent: 8b8f2e74

Refactor DeepGEMM integration (#7150)

Showing 12 changed files with 207 additions and 147 deletions (+207, -147)
Files changed:

python/sglang/srt/layers/moe/ep_moe/kernels.py                              +1   -5
python/sglang/srt/layers/moe/ep_moe/layer.py                                +35  -32
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py                     +5   -5
python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py         +1   -0
python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py    +22  -76
python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py       +26  -0
python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py       +95  -0
python/sglang/srt/layers/quantization/fp8_kernel.py                         +6   -10
python/sglang/srt/layers/quantization/fp8_utils.py                          +3   -4
python/sglang/srt/model_executor/model_runner.py                            +6   -6
python/sglang/srt/models/deepseek_v2.py                                     +3   -7
python/sglang/srt/two_batch_overlap.py                                      +4   -2
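The gist of the change: the single module python/sglang/srt/layers/quantization/deep_gemm.py becomes a deep_gemm_wrapper package (configurer.py for the enable flags, compile_utils.py for warm-up/precompile machinery, entrypoint.py for the public GEMM wrappers), and call sites switch from the private _ENABLE_JIT_DEEPGEMM flag and raw deep_gemm functions to the wrapper's public names. Below is a minimal sketch of the call pattern the refactor standardizes on, inferred only from the imports and call sites visible in the diffs; the helper name and tensor arguments are mine, not sglang's, and a CUDA build with deep_gemm installed is assumed.

import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper


def block_fp8_gemm(a_fp8, a_scale, b_fp8, b_scale, m, n):
    # DeepGEMM only writes bf16 outputs (see the assert in fp8_kernel.py below).
    out = torch.empty((m, n), device=a_fp8.device, dtype=torch.bfloat16)
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        # Public wrapper entrypoint; replaces direct deep_gemm.gemm_fp8_fp8_bf16_nt calls.
        deep_gemm_wrapper.gemm_nt_f8f8bf16((a_fp8, a_scale), (b_fp8, b_scale), out)
    else:
        raise RuntimeError("DeepGEMM JIT is disabled; fall back to the Triton path")
    return out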
python/sglang/srt/layers/moe/ep_moe/kernels.py  (+1, -5)

@@ -4,6 +4,7 @@ from typing import List, Optional
 import torch
 import triton
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import dispose_tensor, is_cuda
@@ -15,11 +16,6 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
 
-try:
-    from deep_gemm import ceil_div
-except ImportError:
-    logger.error(f"Failed to import ceil_div from deep_gemm.")
-
 import triton.language as tl
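The only functional change here is sourcing ceil_div from sglang.math_utils instead of importing it from deep_gemm inside a try/except. A standalone sketch of the ceiling-division behavior this helper is assumed to provide (the body is my assumption, not the sglang source):

def ceil_div(a: int, b: int) -> int:
    # Assumed behavior of sglang.math_utils.ceil_div: round a / b up.
    return (a + b - 1) // b


assert ceil_div(7, 128) == 1
assert ceil_div(256, 128) == 2
assert ceil_div(257, 128) == 3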
python/sglang/srt/layers/moe/ep_moe/layer.py  (+35, -32)

 import logging
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
+from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-try:
-    from deep_gemm import (
-        get_col_major_tma_aligned_tensor,
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    )
-    from sgl_kernel import silu_and_mul
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
-
-    use_deep_gemm = True
-except ImportError:
-    use_deep_gemm = False
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    DeepEPMode,
+    dispose_tensor,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)
 
 _is_hip = is_hip()
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
-            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         self.w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            if _ENABLE_JIT_DEEPGEMM:
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
                     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                 )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
            dtype=torch.bfloat16,
        )
        input_tensor[1] = tma_align_input_scale(input_tensor[1])
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
        )
        del input_tensor
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
        )
        del down_input
        down_input_scale = tma_align_input_scale(down_input_scale)
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            (down_input_fp8, down_input_scale),
            self.w2_weight_fp8,
            down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            hidden_states_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         dispose_tensor(hidden_states_fp8[0])
@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            get_col_major_tma_aligned_tensor(down_input_scale),
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         return down_output
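The masked grouped-GEMM call sites above now go through deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked and only pass recipe=(1, 128, 128) when DEEPGEMM_V202506 is set. A hedged sketch of that call shape follows; the shapes, scale-tensor layouts, and dtypes below are illustrative assumptions, not values from the diff.

import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper

E, M, K, N = 8, 128, 512, 1024  # experts, max tokens per expert, in/out dims

hidden_fp8 = (
    torch.empty(E, M, K, device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty(E, M, K // 128, device="cuda", dtype=torch.float32),  # per-128-group scales (assumed layout)
)
weight_fp8 = (
    torch.empty(E, N, K, device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty(E, N // 128, K // 128, device="cuda", dtype=torch.float32),  # block scales (assumed layout)
)
out = torch.empty(E, M, N, device="cuda", dtype=torch.bfloat16)
masked_m = torch.full((E,), M, device="cuda", dtype=torch.int32)  # valid rows per expert (assumed dtype)

deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
    hidden_fp8,
    weight_fp8,
    out,
    masked_m,
    expected_m=M,
    recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
)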
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (+5, -5)

 import logging
 from dataclasses import dataclass
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
@@ -236,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
             hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             (
                 hidden_states,
                 topk_idx,
@@ -345,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
@@ -409,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
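The dispatcher logic is unchanged; it only reads the flag from the wrapper now. When JIT DeepGEMM is enabled, hidden states are quantized to FP8 in 128-element groups before dispatch and DeepEP is asked for 128-aligned expert chunks. A small gating sketch mirroring the diff, with the surrounding variables as placeholders:

from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8


def maybe_quantize_for_dispatch(hidden_states):
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        # 128 matches the hard-coded block-quant group size in the diff.
        hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
    return hidden_states


# Expert alignment follows the same flag: 128 when DeepGEMM is used, else 1.
expert_alignment = 128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1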
python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py  (new file, +1)

from .entrypoint import *
python/sglang/srt/layers/quantization/deep_gemm.py → python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py  (+22, -76)

@@ -5,33 +5,24 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple
 
-import torch
 from tqdm.contrib.concurrent import thread_map
 
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_V202506,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
 
-_ENABLE_JIT_DEEPGEMM = False
 try:
-    import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
-
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
 except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
+    pass
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -52,8 +43,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if _ENABLE_JIT_DEEPGEMM:
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,6 +107,7 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
@@ -127,6 +121,7 @@ def _compile_warning_1():
         )
 
 
+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +233,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
 
 
+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +266,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -304,56 +299,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m)
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -380,13 +325,14 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 @contextmanager
-def configure_deep_gemm_num_sms(num_sms):
-    if num_sms is None:
-        yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if DEEPGEMM_V202506:
+        yield
+        return
+
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    with _log_jit_build(m, n, k, kernel_type):
+        yield
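compile_utils keeps the warm-up and precompile machinery but now exposes it through a single deep_gemm_execution_hook context manager instead of per-kernel wrapper functions (those moved to entrypoint.py). A hedged sketch of how a wrapped kernel is expected to use the hook; the kernel invocation below is a placeholder, and the real usages live in deep_gemm_wrapper/entrypoint.py.

from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils


def run_plain_gemm(lhs, rhs, out):
    m, k = lhs[0].shape
    n, _ = rhs[0].shape
    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
    # The hook precompiles matching kernel configs (skipped when the 2025-06
    # DeepGEMM API is active) and logs JIT builds outside the precompile stage.
    with compile_utils.deep_gemm_execution_hook(m, n, k, 1, kernel_type):
        ...  # invoke the raw deep_gemm kernel here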
python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py  (new file, +26)

import logging

from sglang.srt.utils import get_bool_env_var, get_device_sm

logger = logging.getLogger(__name__)


def _compute_enable_deep_gemm():
    try:
        import deep_gemm
    except ImportError:
        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
        return False

    sm_version = get_device_sm()
    if sm_version < 90:
        return False

    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")


ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

DEEPGEMM_V202506 = False
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
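configurer.py centralizes the enable decision that used to live in deep_gemm.py's try/except: DeepGEMM JIT is on only if deep_gemm imports, the device is SM90 or newer, and SGL_ENABLE_JIT_DEEPGEMM is not set to false. A small sketch of driving the flag from the environment (the env-var name comes from the file above; this assumes the wrapper has not been imported earlier in the process, since the flag is computed at import time):

import os

os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"  # must be set before importing sglang

from sglang.srt.layers.quantization import deep_gemm_wrapper

assert not deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM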
python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py  (new file, +95)

import logging
from contextlib import contextmanager
from typing import Tuple

import torch

from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
    DEEPGEMM_SCALE_UE8M0,
    DEEPGEMM_V202506,
    ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)

if ENABLE_JIT_DEEPGEMM:
    import deep_gemm
    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
    from deep_gemm import get_col_major_tma_aligned_tensor
    from deep_gemm import (
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
    )
    from deep_gemm import (
        m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
    )


def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    recipe=None,
):
    num_groups, _, k = lhs[0].shape
    _, n, _ = rhs[0].shape
    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED

    with compile_utils.deep_gemm_execution_hook(
        expected_m, n, k, num_groups, kernel_type
    ):
        _grouped_gemm_nt_f8f8bf16_masked_raw(
            lhs, rhs, out, masked_m, expected_m, recipe=recipe
        )


def grouped_gemm_nt_f8f8bf16_contig(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    m_indices: torch.Tensor,
):
    m, k = lhs[0].shape
    num_groups, n, _ = rhs[0].shape
    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)


def gemm_nt_f8f8bf16(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
):
    m, k = lhs[0].shape
    n, _ = rhs[0].shape
    num_groups = 1
    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
        _gemm_nt_f8f8bf16_raw(
            lhs,
            rhs,
            out,
        )


def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
    compile_utils.update_deep_gemm_config(gpu_id, server_args)


@contextmanager
def configure_deep_gemm_num_sms(num_sms):
    if num_sms is None:
        yield
    else:
        original_num_sms = deep_gemm.get_num_sms()
        deep_gemm.set_num_sms(num_sms)
        try:
            yield
        finally:
            deep_gemm.set_num_sms(original_num_sms)
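entrypoint.py is the new public surface: the three GEMM wrappers, update_deep_gemm_config, configure_deep_gemm_num_sms, and a re-export of get_col_major_tma_aligned_tensor, all made available as deep_gemm_wrapper.<name> through the package's `from .entrypoint import *`. Since the masked path is sketched earlier, here is a hedged sketch of the contiguous grouped-GEMM entrypoint; shapes, scale layouts, and the m_indices content are illustrative assumptions.

import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper

E, M, K, N = 4, 256, 512, 1024  # experts, total rows across experts, in/out dims

lhs = (
    torch.empty(M, K, device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty(M, K // 128, device="cuda", dtype=torch.float32),  # assumed scale layout
)
rhs = (
    torch.empty(E, N, K, device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty(E, N // 128, K // 128, device="cuda", dtype=torch.float32),  # assumed scale layout
)
out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
# Expert id of each flattened row; in DeepEPMoE rows are grouped by expert.
m_indices = torch.arange(M, device="cuda", dtype=torch.int32) // (M // E)

deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(lhs, rhs, out, m_indices)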
python/sglang/srt/layers/quantization/fp8_kernel.py  (+6, -10)

@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
@@ -44,10 +45,6 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
@@ -67,7 +64,6 @@ else:
 fp8_max = torch.finfo(fp8_dtype).max
 fp8_min = -fp8_max
 
-
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
@@ -77,7 +73,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -797,12 +793,12 @@ def w8a8_block_fp8_matmul_deepgemm(
     M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
     # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
 
     if supports_custom_op():
         torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
     else:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     return C
@@ -896,7 +892,7 @@ def w8a8_block_fp8_matmul(
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return w8a8_block_fp8_matmul_deepgemm(
             A, B, As, Bs, block_size, output_dtype=output_dtype
         )
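fp8_kernel.py keeps its custom-op indirection (torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt when custom ops are supported, a direct call otherwise); only the callee changes to deep_gemm_wrapper.gemm_nt_f8f8bf16. A compact dispatch sketch mirroring the DeepGEMM branch above; the tensor arguments are placeholders and I am assuming supports_custom_op is importable from sglang.srt.utils.

import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import supports_custom_op  # assumed import location


def run_deepgemm(A, As, B, Bs, C):
    # DeepGEMM only writes bf16 outputs, and the JIT path must be enabled.
    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
    if supports_custom_op():
        # Route through the registered custom op when available.
        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
    else:
        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
    return C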
python/sglang/srt/layers/quantization/fp8_utils.py  (+3, -4)

-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
 
 from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -15,7 +15,6 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,
@@ -138,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return cutlass_w8a8_block_fp8_linear_with_fallback
     elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
-    elif _ENABLE_JIT_DEEPGEMM:
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
     else:
        return triton_w8a8_block_fp8_linear
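The linear-layer dispatch keeps its priority order (cutlass fallback, then aiter, then DeepGEMM, then Triton); only the DeepGEMM check now reads deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM. A tiny usage sketch of resolving the dispatched kernel once and reusing it; the function name comes from the diff, while the arguments of the returned callable are not shown there and are left out.

from sglang.srt.layers.quantization.fp8_utils import dispatch_w8a8_block_fp8_linear

w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
# Which branch wins depends on hardware, env vars, and whether DeepGEMM JIT is on,
# e.g. deepgemm_w8a8_block_fp8_linear_with_fallback on an SM90 GPU with deep_gemm.
print(w8a8_block_fp8_linear.__name__)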
python/sglang/srt/model_executor/model_runner.py  (+6, -6)

@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
-from sglang.srt.layers.quantization.deep_gemm import (
-    _ENABLE_JIT_DEEPGEMM,
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -205,8 +205,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
 
         # Update deep gemm configure
-        if _ENABLE_JIT_DEEPGEMM:
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
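ModelRunner now initializes DeepGEMM through the wrapper as well. A minimal sketch of the init-time pattern the diff converges on; gpu_id and server_args are whatever the runner was constructed with.

from sglang.srt.layers.quantization import deep_gemm_wrapper


def maybe_init_deep_gemm(gpu_id, server_args):
    # Only touch DeepGEMM configuration when the JIT path is actually enabled.
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)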
python/sglang/srt/models/deepseek_v2.py  (+3, -7)

@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
 else:
     from vllm._custom_ops import awq_dequantize
@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                     "SGL_USE_DEEPGEMM_BMM", "false"
                 ):
                     block_scale = weight_scale
python/sglang/srt/two_batch_overlap.py  (+4, -2)

@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
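Two-batch overlap caps the number of SMs DeepGEMM may use during the overlapped forward, presumably to leave SMs free for the other micro-batch, and the context manager restores the original SM count on exit. A usage sketch; the SM value is illustrative and run_overlapped_batches is a placeholder for the overlapped forward passes.

from sglang.srt.layers.quantization import deep_gemm_wrapper


def forward_with_sm_budget(run_overlapped_batches, deep_gemm_num_sms=None):
    # None leaves DeepGEMM's SM count untouched; otherwise the previous value
    # is restored when the block exits, as in _model_forward_tbo.
    with deep_gemm_wrapper.configure_deep_gemm_num_sms(deep_gemm_num_sms):
        return run_overlapped_batches()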