zhaoyu6 / sglang · Commits

Commit 177320a5 (Unverified)
Clean up imports (#5467)
Authored Apr 16, 2025 by Lianmin Zheng; committed via GitHub on Apr 16, 2025.
Parent: d7bc19a4

The commit contains 51 changes; the first 20 changed files are shown on this page, with 63 additions and 217 deletions (+63, -217):
python/sglang/__init__.py                                                            +2  -4
python/sglang/bench_serving.py                                                       +0  -4
python/sglang/lang/__init__.py                                                       +0  -0
python/sglang/lang/backend/anthropic.py                                              +0  -4
python/sglang/lang/backend/base_backend.py                                           +1  -1
python/sglang/lang/backend/openai.py                                                 +1  -1
python/sglang/lang/backend/vertexai.py                                               +0  -1
python/sglang/lang/compiler.py                                                       +1  -7
python/sglang/lang/tracer.py                                                         +3  -7
python/sglang/srt/_custom_ops.py                                                     +0  -2
python/sglang/srt/custom_op.py                                                       +0  -62
python/sglang/srt/entrypoints/verl_engine.py                                         +1  -2
python/sglang/srt/layers/activation.py                                               +6  -8
python/sglang/srt/layers/layernorm.py                                                +1  -1
python/sglang/srt/layers/moe/ep_moe/layer.py                                         +12 -26
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py                           +12 -19
python/sglang/srt/layers/moe/topk.py                                                 +10 -20
python/sglang/srt/layers/parameter.py                                                +0  -2
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py       +1  -2
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py   +12 -44
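To inspect the files not rendered on this page, the standard git commands suffice from a clone of the repository: `git show --stat 177320a5` prints per-file addition/deletion counts for the whole commit, and `git show 177320a5 -- <path>` prints the diff of an individual file.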
python/sglang/__init__.py
@@ -24,6 +24,7 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.global_config import global_config
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.choices import (
     greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
     unconditional_likelihood_normalized,
 )
 from sglang.utils import LazyImport
+from sglang.version import __version__

 ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
 VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

-# Other configs
-from sglang.global_config import global_config
-from sglang.version import __version__
-
 __all__ = [
     "Engine",
     "Runtime",
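The `ServerArgs = LazyImport(...)` bindings visible in this hunk are what make the import shuffle above cheap: heavy `sglang.srt` modules are not imported until first use. The real helper lives in `sglang.utils`; the following is only a minimal sketch of the pattern, not the actual implementation:

    import importlib

    class LazyImport:
        """Proxy that imports `module_name` and resolves `class_name`
        only on first attribute access or call."""

        def __init__(self, module_name: str, class_name: str):
            self.module_name = module_name
            self.class_name = class_name
            self._target = None

        def _load(self):
            if self._target is None:
                module = importlib.import_module(self.module_name)
                self._target = getattr(module, self.class_name)
            return self._target

        def __getattr__(self, name):
            return getattr(self._load(), name)

        def __call__(self, *args, **kwargs):
            return self._load()(*args, **kwargs)

    # Nothing under sglang.srt is imported until ServerArgs is actually used.
    ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")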
python/sglang/bench_serving.py
@@ -707,10 +707,6 @@ def sample_random_requests(
         # Download sharegpt if necessary
         if not os.path.isfile(dataset_path):
-            print(
-                "If you do not want to randomly sample from a dataset,"
-                " please use --dataset-name random-ids."
-            )
             dataset_path = download_and_cache_file(SHAREGPT_URL)

         # Load the dataset.
python/sglang/lang/__init__.py
deleted (file mode 100644, empty file; viewable at parent d7bc19a4)
python/sglang/lang/backend/anthropic.py
-from typing import List, Optional, Union
-
-import numpy as np
-
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
python/sglang/lang/backend/base_backend.py
-from typing import Callable, List, Optional, Union
+from typing import List, Optional, Union

 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
python/sglang/lang/backend/openai.py
@@ -2,7 +2,7 @@ import dataclasses
 import logging
 import time
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import List, Optional, Union

 import numpy as np
python/sglang/lang/backend/vertexai.py
 import os
 import warnings
-from typing import Optional

 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
python/sglang/lang/compiler.py
@@ -5,13 +5,7 @@ from typing import List, Union
 from sglang.global_config import global_config
 from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
-from sglang.lang.ir import (
-    SglArgument,
-    SglConstantText,
-    SglExpr,
-    SglSamplingParams,
-    SglVariable,
-)
+from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


 def compile_func(function, backend):
python/sglang/lang/tracer.py
 """Tracing a program."""

 import uuid
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional

-from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.interpreter import ProgramState, ProgramStateGroup
 from sglang.lang.ir import (
     SglArgument,
-    SglCommitLazy,
-    SglConcateAndAppend,
     SglConstantText,
     SglExpr,
     SglExprList,
     SglFork,
-    SglFunction,
     SglGen,
     SglGetForkItem,
     SglRoleBegin,
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
         self.cur_role = None

     def _execute_var_scope_end(self, expr: SglVarScopeEnd):
-        new_node = SglVariable(name, source=self.last_node)
-        self.variables[name] = new_node
+        new_node = SglVariable(expr.name, source=self.last_node)
+        self.variables[expr.name] = new_node

     def get_var(self, name):
         ret = self.arguments.get(name, None)
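Note that the second tracer hunk is a genuine bug fix riding along with the import cleanup: `name` was never bound inside `_execute_var_scope_end`, so closing a variable scope during tracing would have raised `NameError`; the variable's name lives on the expression node, hence `expr.name`.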
python/sglang/srt/_custom_ops.py
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import logging
-import os
 from typing import List, Tuple

 import torch
-import torch.library

 from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
python/sglang/srt/custom_op.py
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
             return self.forward_hip
         else:
             return self.forward_native
-
-
-if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    def scaled_fp8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None,
-        num_token_padding: Optional[int] = None,
-        use_per_token_if_dynamic: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Quantize input tensor to FP8 (8-bit floating point) format.
-
-        Args:
-            input (torch.Tensor): Input tensor to be quantized
-            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-                If None, scales will be computed dynamically.
-            num_token_padding (Optional[int]): If specified, pad the first dimension
-                of the output to at least this value.
-            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-                determines the quantization granularity:
-                - True: compute scale per token
-                - False: compute single scale per tensor
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-                - quantized_tensor: The FP8 quantized version of input
-                - scale_tensor: The scaling factors used for quantization
-
-        Raises:
-            AssertionError: If input is not 2D or if static scale's numel != 1
-        """
-        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-        shape = input.shape
-        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-        if num_token_padding:
-            shape = (max(num_token_padding, input.shape[0]), shape[1])
-        output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-        if scale is None:
-            # Dynamic scaling
-            if use_per_token_if_dynamic:
-                scale = torch.empty(
-                    (shape[0], 1), device=input.device, dtype=torch.float32
-                )
-                sgl_per_token_quant_fp8(input, output, scale)
-            else:
-                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-                sgl_per_tensor_quant_fp8(
-                    input, output, scale, is_static=False
-                )  # False for dynamic
-        else:
-            # Static scaling
-            assert (
-                scale.numel() == 1
-            ), f"Expected scalar scale, got numel={scale.numel()}"
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=True
-            )  # True for static
-
-        return output, scale
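The deleted helper is not dropped outright: the hunks in the MoE and quantization files below switch their imports to `sglang.srt.layers.quantization.fp8_kernel.scaled_fp8_quant`. Assuming the consolidated function keeps the signature documented in the deleted code above, a hypothetical usage sketch (CUDA build with `sgl_kernel` required) looks like this:

    import torch

    from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

    x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

    # Dynamic per-tensor scaling: a single scale for the whole tensor.
    q, s = scaled_fp8_quant(x)  # s.numel() == 1

    # Dynamic per-token scaling: one scale per row of the 2D input.
    q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)  # s_tok: (16, 1)

    # Static scaling with a precomputed scalar scale.
    q_stat, _ = scaled_fp8_quant(x, scale=torch.tensor([0.05], device="cuda"))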
python/sglang/srt/entrypoints/verl_engine.py
@@ -19,11 +19,10 @@ import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.server import Engine
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
python/sglang/srt/layers/activation.py
@@ -21,13 +21,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_cuda_available
-
-_is_cuda = is_cuda_available()
-
-if _is_cuda:
-    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 logger = logging.getLogger(__name__)
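For readers unfamiliar with the fused kernels activation.py guards behind `_is_cuda`: `silu_and_mul` implements the gated activation used by LLaMA-style MLPs, fusing the split, SiLU, and multiply into one pass. A pure-PyTorch sketch of its semantics (reference only, not the kernel itself):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension in half: SiLU(gate) * up.
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up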
python/sglang/srt/layers/layernorm.py
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available

 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
         rmsnorm,
     )

-from sglang.srt.custom_op import CustomOp

 logger = logging.getLogger(__name__)
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple

 import torch
+from torch.nn import Module

 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False

-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

-_is_cuda = is_cuda()
+_is_hip = is_hip()

-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant

 logger = logging.getLogger(__name__)

-_is_hip = is_hip()
-
 _buffer = None

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -740,19 +734,11 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )
             for expert in range(layer.num_experts_per_partition):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
 _is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)

 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,12 +768,7 @@ def invoke_fused_moe_kernel(
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
             # activations apply per-token quantization when weights apply per-channel quantization by default
-            if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
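The `padding_size` and `enable_moe_align_block_size_triton` switches moved by the second hunk are evaluated once at module import, so the environment variables must be set before sglang is imported. A hypothetical snippet:

    import os

    # Must happen before the fused_moe module is imported, since the flags
    # are read at import time.
    os.environ["MOE_PADDING"] = "1"  # enables the 128-element padding above
    os.environ["ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"

    from sglang.srt.layers.moe.fused_moe_triton import fused_moe  # noqa: E402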
python/sglang/srt/layers/moe/topk.py
@@ -13,7 +13,6 @@
 # ==============================================================================

 import math
-import os
 from typing import Callable, Optional

 import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
 if _is_cuda:
     from sgl_kernel import moe_fused_gate

+if _is_cuda or _is_hip:
+    from sgl_kernel import topk_softmax
+

 expert_distribution_recorder = ExpertDistributionRecorder()
@@ -59,11 +62,6 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda or _is_hip:
-        from sgl_kernel import topk_softmax
-    else:
-        from vllm import _custom_ops as vllm_ops
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )

-    if _is_cuda or _is_hip:
-        topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
-    else:
-        vllm_ops.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
+    topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
     del token_expert_indicies

     if renormalize:
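A hypothetical call sketch for `fused_topk` after this change (CUDA or HIP build assumed, since `topk_softmax` now comes from `sgl_kernel` unconditionally on those platforms; the return convention of per-token weights plus int32 expert ids is assumed from the buffers allocated above):

    import torch

    from sglang.srt.layers.moe.topk import fused_topk

    num_tokens, hidden, num_experts, k = 8, 4096, 64, 2
    hidden_states = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.float16)
    gating_output = torch.randn(num_tokens, num_experts, device="cuda")

    # Pick the top-k experts per token from the router logits and
    # renormalize their weights to sum to 1.
    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, k, renormalize=True)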
python/sglang/srt/layers/parameter.py
@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 import logging
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

 logger = logging.getLogger(__name__)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 import enum
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import Callable, List, Optional

 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.fused_moe_triton import (
-        FusedMoE,
-        FusedMoEMethodBase,
-        FusedMoeWeightScaleSupported,
-    )
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
 _is_cuda = is_cuda()

-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
+if not _is_cuda:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant

 try:
     import vllm
@@ -58,8 +51,6 @@ __all__ = [
 class CompressedTensorsMoEMethod:
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
         return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                     start += shard_size

         layer.w13_weight_scale = torch.nn.Parameter(
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
-        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             self.num_bits,
         )
         replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.topk import select_experts

         assert activation == "silu", "Only SiLU activation is supported."
-        if not VLLM_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed, to use fused_marlin_moe, please install vllm"
-            )
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for " "fused Marlin MoE method."
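Context for the first hunk of this file: an `if TYPE_CHECKING:` block is evaluated only by static type checkers, never at runtime, which is why the old code also carried function-local imports of the same names wherever they were actually needed. The commit drops the guard and the now-unused local imports together. A generic sketch of the pattern being removed (not the sglang code itself):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to mypy/pyright only; no runtime import, no import cycle.
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

    def apply(layer: "FusedMoE") -> None:  # string annotation, resolved lazily
        ...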