change / sglang · Commits · 177320a5

Commit 177320a5 (unverified)
Authored Apr 16, 2025 by Lianmin Zheng; committed via GitHub on Apr 16, 2025
Parent: d7bc19a4

Clean up imports (#5467)

Changes: 51 files
Showing 20 of the 51 changed files on this page, with 63 additions and 217 deletions (+63, -217).
python/sglang/__init__.py  +2 -4
python/sglang/bench_serving.py  +0 -4
python/sglang/lang/__init__.py  +0 -0
python/sglang/lang/backend/anthropic.py  +0 -4
python/sglang/lang/backend/base_backend.py  +1 -1
python/sglang/lang/backend/openai.py  +1 -1
python/sglang/lang/backend/vertexai.py  +0 -1
python/sglang/lang/compiler.py  +1 -7
python/sglang/lang/tracer.py  +3 -7
python/sglang/srt/_custom_ops.py  +0 -2
python/sglang/srt/custom_op.py  +0 -62
python/sglang/srt/entrypoints/verl_engine.py  +1 -2
python/sglang/srt/layers/activation.py  +6 -8
python/sglang/srt/layers/layernorm.py  +1 -1
python/sglang/srt/layers/moe/ep_moe/layer.py  +12 -26
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  +12 -19
python/sglang/srt/layers/moe/topk.py  +10 -20
python/sglang/srt/layers/parameter.py  +0 -2
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  +1 -2
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py  +12 -44
python/sglang/__init__.py (View file @ 177320a5)

@@ -24,6 +24,7 @@ from sglang.api import (
    user_end,
    video,
)
from sglang.global_config import global_config
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
    greedy_token_selection,

@@ -31,6 +32,7 @@ from sglang.lang.choices import (
    unconditional_likelihood_normalized,
)
from sglang.utils import LazyImport
from sglang.version import __version__

ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")

@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

# Other configs
from sglang.global_config import global_config
from sglang.version import __version__

__all__ = [
    "Engine",
    "Runtime",
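The hunks above route ServerArgs and the language backends through sglang.utils.LazyImport so that importing the top-level sglang package stays cheap. Below is a minimal sketch of how such a lazy-import helper can work; it is an assumption about the general pattern, not sglang's actual LazyImport implementation, and the demo target (json.JSONDecoder) is only a stand-in.

import importlib

class LazyImportSketch:
    """Defer importing `module_name` until the named attribute is first used."""

    def __init__(self, module_name: str, attr_name: str):
        self._module_name = module_name
        self._attr_name = attr_name
        self._obj = None

    def _load(self):
        if self._obj is None:
            module = importlib.import_module(self._module_name)
            self._obj = getattr(module, self._attr_name)
        return self._obj

    def __call__(self, *args, **kwargs):
        # The real import happens here, on first use, not at package import time.
        return self._load()(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(self._load(), name)

# Usage mirroring the diff, with a standard-library stand-in target:
JSONDecoder = LazyImportSketch("json", "JSONDecoder")
decoder = JSONDecoder()          # the json module is imported only at this point
print(decoder.decode("[1, 2]"))  # [1, 2]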
python/sglang/bench_serving.py (View file @ 177320a5)

@@ -707,10 +707,6 @@ def sample_random_requests(
    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        print(
            "If you do not want to randomly sample from a dataset,"
            " please use --dataset-name random-ids."
        )
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
python/sglang/lang/__init__.py (deleted, 100644 → 0) (View file @ d7bc19a4)
python/sglang/lang/backend/anthropic.py (View file @ 177320a5)

from typing import List, Optional, Union

import numpy as np

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
python/sglang/lang/backend/base_backend.py (View file @ 177320a5)

from typing import Callable, List, Optional, Union
from typing import List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
python/sglang/lang/backend/openai.py (View file @ 177320a5)

@@ -2,7 +2,7 @@ import dataclasses
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union

import numpy as np
python/sglang/lang/backend/vertexai.py (View file @ 177320a5)

import os
import warnings
from typing import Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
python/sglang/lang/compiler.py (View file @ 177320a5)

@@ -5,13 +5,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
    SglArgument,
    SglConstantText,
    SglExpr,
    SglSamplingParams,
    SglVariable,
)
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


def compile_func(function, backend):
python/sglang/lang/tracer.py (View file @ 177320a5)

"""Tracing a program."""

import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
    SglArgument,
    SglCommitLazy,
    SglConcateAndAppend,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglFork,
    SglFunction,
    SglGen,
    SglGetForkItem,
    SglRoleBegin,

@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
        self.cur_role = None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        new_node = SglVariable(name, source=self.last_node)
        self.variables[name] = new_node
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        ret = self.arguments.get(name, None)
python/sglang/srt/_custom_ops.py (View file @ 177320a5)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
import os
from typing import List, Tuple

import torch
import torch.library

from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
python/sglang/srt/custom_op.py (View file @ 177320a5)

@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
            return self.forward_hip
        else:
            return self.forward_native


if _is_cuda:
    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8

    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        use_per_token_if_dynamic: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 (8-bit floating point) format.

        Args:
            input (torch.Tensor): Input tensor to be quantized
            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
                If None, scales will be computed dynamically.
            num_token_padding (Optional[int]): If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
                determines the quantization granularity:
                - True: compute scale per token
                - False: compute single scale per tensor

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - quantized_tensor: The FP8 quantized version of input
                - scale_tensor: The scaling factors used for quantization

        Raises:
            AssertionError: If input is not 2D or if static scale's numel != 1
        """
        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
        shape = input.shape
        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        output = torch.empty(shape, device=input.device, dtype=out_dtype)

        if scale is None:
            # Dynamic scaling
            if use_per_token_if_dynamic:
                scale = torch.empty(
                    (shape[0], 1), device=input.device, dtype=torch.float32
                )
                sgl_per_token_quant_fp8(input, output, scale)
            else:
                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
                sgl_per_tensor_quant_fp8(
                    input, output, scale, is_static=False
                )  # False for dynamic
        else:
            # Static scaling
            assert (
                scale.numel() == 1
            ), f"Expected scalar scale, got numel={scale.numel()}"
            sgl_per_tensor_quant_fp8(input, output, scale, is_static=True)  # True for static

        return output, scale
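The block above, removed from custom_op.py by this commit (the other hunks now import scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel, which presumably hosts the equivalent logic), documents the FP8 quantization API. As a rough illustration of the dynamic per-tensor path described in that docstring, here is a minimal pure-PyTorch sketch; it is an assumption-level reference, not the sgl_kernel implementation, and reference_scaled_fp8_quant is a made-up name.

import torch

def reference_scaled_fp8_quant(x: torch.Tensor):
    # Dynamic per-tensor scaling: one scale maps the max magnitude of x onto
    # the representable FP8 (e4m3) range, mirroring the docstring above.
    assert x.ndim == 2, f"Expected 2D input tensor, got {x.ndim}D"
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = (x.abs().max().clamp(min=1e-12) / finfo.max).float()
    q = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale  # dequantize with q.float() * scale

x = torch.randn(4, 8)
q, s = reference_scaled_fp8_quant(x)
print(q.dtype, q.shape, s.item())  # torch.float8_e4m3fn, (4, 8), per-tensor scale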
python/sglang/srt/entrypoints/verl_engine.py (View file @ 177320a5)

@@ -19,11 +19,10 @@ import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
python/sglang/srt/layers/activation.py (View file @ 177320a5)

@@ -21,13 +21,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from sglang.srt.utils import is_cuda_available

_is_cuda = is_cuda_available()

if _is_cuda:
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    divide,

@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_cuda_available, set_weight_attrs

_is_cuda = is_cuda_available()

if _is_cuda:
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

logger = logging.getLogger(__name__)
python/sglang/srt/layers/layernorm.py (View file @ 177320a5)

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available

_is_cuda = is_cuda_available()

@@ -31,7 +32,6 @@ if _is_cuda:
    rmsnorm,
)

from sglang.srt.custom_op import CustomOp

logger = logging.getLogger(__name__)
python/sglang/srt/layers/moe/ep_moe/layer.py (View file @ 177320a5)

@@ -2,6 +2,7 @@ import logging
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn import Module

try:
    from deep_gemm import (

@@ -13,8 +14,6 @@ try:
except ImportError:
    use_deep_gemm = False

from torch.nn import Module

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,

@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
    from vllm import _custom_ops as vllm_ops

if _is_hip:
    from vllm._custom_ops import scaled_fp8_quant

logger = logging.getLogger(__name__)

_is_hip = is_hip()

_buffer = None


class GroupedGemmRunner(torch.nn.Module):
    flashinfer_gemm_warpper = None

@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
            )
            for expert in range(layer.num_experts_per_partition):
                if _is_cuda:
                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                    )
                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                    )
                else:
                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                    )
                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                    )
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )

            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (View file @ 177320a5)

@@ -13,6 +13,7 @@ import triton
import triton.language as tl

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,

@@ -22,28 +23,25 @@ from sglang.srt.utils import (
)

_is_hip = is_hip()

logger = logging.getLogger(__name__)

padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import gelu_and_mul, silu_and_mul
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
    from vllm import _custom_ops as vllm_ops
    from vllm._custom_ops import scaled_fp8_quant

if _is_cuda or _is_hip:
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

logger = logging.getLogger(__name__)

padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)


@triton.jit
def write_zeros_to_output(
    c_ptr,

@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
            # activation tensor-wise fp8 quantization, dynamic or static
            padded_size = padding_size
            # activations apply per-token quantization when weights apply per-channel quantization by default
            if _is_cuda:
                A, A_scale = sgl_scaled_fp8_quant(
                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
                )
            else:
                A, A_scale = vllm_ops.scaled_fp8_quant(
                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
                )
            A, A_scale = scaled_fp8_quant(
                A, A_scale, use_per_token_if_dynamic=per_channel_quant
            )
        else:
            # activation block-wise fp8 quantization
            assert len(block_shape) == 2
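Both ep_moe/layer.py and fused_moe.py drop the per-call-site branching between sgl_kernel and vllm in favor of a single import of scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel. The sketch below illustrates that import-time dispatch pattern under stated assumptions: the file name fp8_quant_dispatch.py and the helper _resolve_scaled_fp8_quant are hypothetical, the sgl_kernel wrapper follows the removed custom_op.py code shown earlier, and this is not a reproduction of fp8_kernel.py.

# fp8_quant_dispatch.py (hypothetical file name): choose the FP8 quant backend
# once, at import time, and export one `scaled_fp8_quant` symbol so callers do
# not branch on _is_cuda at every call site.
from typing import Optional

import torch

def _resolve_scaled_fp8_quant():
    try:
        # CUDA builds: wrap the sgl_kernel primitive, following the removed
        # custom_op.py wrapper shown earlier in this diff.
        from sgl_kernel import sgl_per_tensor_quant_fp8

        def _sgl_scaled_fp8_quant(
            x: torch.Tensor, scale: Optional[torch.Tensor] = None
        ) -> tuple[torch.Tensor, torch.Tensor]:
            output = torch.empty(x.shape, device=x.device, dtype=torch.float8_e4m3fn)
            if scale is None:
                scale = torch.zeros(1, device=x.device, dtype=torch.float32)
                sgl_per_tensor_quant_fp8(x, output, scale, is_static=False)
            else:
                sgl_per_tensor_quant_fp8(x, output, scale, is_static=True)
            return output, scale

        return _sgl_scaled_fp8_quant
    except ImportError:
        pass
    try:
        # Fallback matching the `from vllm._custom_ops import scaled_fp8_quant`
        # lines in the hunks above.
        from vllm._custom_ops import scaled_fp8_quant
        return scaled_fp8_quant
    except ImportError:
        return None  # neither backend available (e.g. a CPU-only environment)

scaled_fp8_quant = _resolve_scaled_fp8_quant()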
python/sglang/srt/layers/moe/topk.py (View file @ 177320a5)

@@ -13,7 +13,6 @@
# ==============================================================================

import math
import os
from typing import Callable, Optional

import torch

@@ -29,6 +28,10 @@ _is_hip = is_hip()
if _is_cuda:
    from sgl_kernel import moe_fused_gate

if _is_cuda or _is_hip:
    from sgl_kernel import topk_softmax

expert_distribution_recorder = ExpertDistributionRecorder()

@@ -59,11 +62,6 @@ def fused_topk(
    topk: int,
    renormalize: bool,
):
    if _is_cuda or _is_hip:
        from sgl_kernel import topk_softmax
    else:
        from vllm import _custom_ops as vllm_ops

    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape

@@ -76,20 +74,12 @@ def fused_topk(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    if _is_cuda or _is_hip:
        topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),
        )
    else:
        vllm_ops.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),
        )
    topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),
    )

    del token_expert_indicies

    if renormalize:
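For context on what fused_topk computes with the topk_softmax kernel: softmax over the router logits, keep the top-k experts per token, and optionally renormalize the kept weights. A plain-PyTorch reference of that gating step follows; reference_fused_topk is a made-up name and the sketch ignores the fused kernel's output layout and performance.

import torch

def reference_fused_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # gating_output: [num_tokens, num_experts] router logits.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)  # per-token expert choice
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

logits = torch.randn(4, 8)                              # 4 tokens, 8 experts
weights, ids = reference_fused_topk(logits, topk=2, renormalize=True)
print(weights.shape, ids.shape)                         # torch.Size([4, 2]) torch.Size([4, 2])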
python/sglang/srt/layers/parameter.py (View file @ 177320a5)

@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_rank

__all__ = [
    "BasevLLMParameter",
    "PackedvLLMParameter",
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py (View file @ 177320a5)

# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging

@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format,
    should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

logger = logging.getLogger(__name__)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (View file @ 177320a5)

# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import enum
import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import Callable, List, Optional

import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

if TYPE_CHECKING:
    from sglang.srt.layers.moe.fused_moe_triton import (
        FusedMoE,
        FusedMoEMethodBase,
        FusedMoeWeightScaleSupported,
    )

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
    all_close_1d,

@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
_is_cuda = is_cuda()

if _is_cuda:
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
if not _is_cuda:
    from vllm import _custom_ops as vllm_ops
    from vllm._custom_ops import scaled_fp8_quant

try:
    import vllm

@@ -58,8 +51,6 @@ __all__ = [
class CompressedTensorsMoEMethod:
    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        if cls is CompressedTensorsMoEMethod:
            return super().__new__(cls)
        return super().__new__(cls)

@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                )
            return CompressedTensorsWNA16MoEMethod(quant_config)
        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):

@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    def __init__(
        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoEMethodBase,
            FusedMoeWeightScaleSupported,
        )

        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(

@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                    layer.w13_weight[expert_id][start : start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id],
                )
                (
                    layer.w13_weight[expert_id][start : start + shard_size, :],
                    _,
                ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                if _is_cuda:
                    (
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        _,
                    ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                else:
                    (
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        _,
                    ) = vllm_ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id]
                    )
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(

@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
    def __init__(
        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoEMethodBase,
            FusedMoeWeightScaleSupported,
        )

        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.

@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            requires_grad=False,
        )
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,

@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,

@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.moe.topk import select_experts

        assert activation == "silu", "Only SiLU activation is supported."
        if not VLLM_AVAILABLE:
            raise ImportError(
                "vllm is not installed, to use fused_marlin_moe, please install vllm"
            )
        if expert_map is not None:
            raise NotImplementedError(
                "Expert Parallelism is not supported for " "fused Marlin MoE method."
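The hunks above move the fused_moe_triton imports behind a typing.TYPE_CHECKING guard (plus local imports inside methods), a common way to keep type hints without importing the module at runtime and risking circular imports. A self-contained sketch of the guard pattern, using a standard-library class as a stand-in for FusedMoE:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from decimal import Decimal  # stand-in for FusedMoE in this sketch

def describe(value: Decimal) -> str:
    # With `from __future__ import annotations`, the hint is stored as a string,
    # so Decimal does not need to exist at runtime.
    return f"value={value!r}"

print(describe(3))  # runs fine: the annotation is never evaluated at runtime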