change / sglang / Commits / 177320a5

Commit 177320a5 (unverified)
Authored Apr 16, 2025 by Lianmin Zheng; committed via GitHub on Apr 16, 2025
Parent: d7bc19a4

Clean up imports (#5467)

Changes: 51 files in this commit. Showing 20 changed files with 288 additions and 306 deletions (+288, -306).
Changed files shown on this page (20 of 51):

- python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py (+0, -3)
- python/sglang/srt/layers/quantization/fp8.py (+113, -132)
- python/sglang/srt/layers/quantization/fp8_kernel.py (+76, -45)
- python/sglang/srt/layers/quantization/fp8_utils.py (+54, -63)
- python/sglang/srt/layers/quantization/utils.py (+5, -11)
- python/sglang/srt/layers/quantization/w8a8_int8.py (+5, -7)
- python/sglang/srt/layers/rotary_embedding.py (+3, -2)
- python/sglang/srt/lora/backend/__init__.py (+0, -25)
- python/sglang/srt/lora/backend/base_backend.py (+18, -2)
- python/sglang/srt/lora/backend/flashinfer_backend.py (+1, -1)
- python/sglang/srt/lora/backend/triton_backend.py (+1, -1)
- python/sglang/srt/lora/layers.py (+1, -1)
- python/sglang/srt/lora/lora.py (+1, -1)
- python/sglang/srt/lora/lora_manager.py (+1, -1)
- python/sglang/srt/managers/detokenizer_manager.py (+0, -1)
- python/sglang/srt/managers/mm_utils.py (+4, -3)
- python/sglang/srt/managers/multimodal_processor.py (+0, -2)
- python/sglang/srt/managers/multimodal_processors/base_processor.py (+3, -2)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+2, -2)
- python/sglang/srt/model_executor/model_runner.py (+0, -1)
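The hunks below share one theme: the choice between the sgl-kernel CUDA path and the vllm fallback for FP8 quantization is made once, at module level in `fp8_kernel.py`, and call sites now import a single `scaled_fp8_quant` symbol instead of branching on `_is_cuda` locally. A minimal sketch of that dispatch pattern follows; the module and helper names here are illustrative, not the actual sglang code.

```python
# fp8_dispatch_sketch.py -- illustrative only; in sglang the real dispatch
# lives in python/sglang/srt/layers/quantization/fp8_kernel.py.
from typing import Optional

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max


def _cuda_backend_available() -> bool:
    # Stand-in for sglang's is_cuda() check (sgl_kernel is CUDA-only).
    return torch.cuda.is_available()


if _cuda_backend_available():

    def scaled_fp8_quant(x: torch.Tensor, scale: Optional[torch.Tensor] = None):
        """Quantize a 2D tensor to FP8, computing a per-tensor scale if none is given."""
        if scale is None:
            scale = x.abs().amax().float().clamp(min=1e-12) / FP8_MAX
        q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
        return q, scale

else:
    # Non-CUDA builds would fall back to another provider (vllm in sglang).
    def scaled_fp8_quant(x: torch.Tensor, scale: Optional[torch.Tensor] = None):
        raise ImportError("no FP8 quantization backend available in this sketch")


# Call sites then need only:
#   from fp8_dispatch_sketch import scaled_fp8_quant
#   q, s = scaled_fp8_quant(weight)
```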
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py (view file @ 177320a5)

```diff
@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     Fp8LinearOp,
-    maybe_create_device_identity,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
```
python/sglang/srt/layers/quantization/fp8.py (view file @ 177320a5)

```diff
@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )
 
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
     permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
 if _is_hip:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
         )
         layer.logical_widths = output_partition_sizes
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.weight_scale_inv.data, requires_grad=False
             )
             return
 
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
 
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.use_marlin:
-            try:
-                prepare_fp8_layer_for_marlin(layer)
-                # Activations not quantized for marlin.
-                del layer.input_scale
-            except ImportError:
-                self.use_marlin = False
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
     def apply(
         self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-            try:
-                return apply_fp8_marlin_linear(
-                    input=x,
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
-                    workspace=layer.workspace,
-                    size_n=layer.output_size_per_partition,
-                    size_k=layer.input_size_per_partition,
-                    bias=bias,
-                )
-            except ImportError:
-                self.use_marlin = False
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
 
         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
 
             for expert in range(layer.num_experts):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
-            assert not no_combine, f"{no_combine=} is not supported."
-            return ck_moe_2stages_win4(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        if _is_hip and get_bool_env_var("CK_MOE"):
-            assert not no_combine, f"{no_combine=} is not supported."
-            if self.block_quant:
-                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return ck_moe_2stages(
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+                assert not no_combine, f"{no_combine=} is not supported."
+                return ck_moe_2stages_win4(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
                         else ActivationType.Gelu
                     ),
                 )
-        else:
-            # Expert fusion with FP8 quantization
-            return fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size,
-                no_combine=no_combine,
-            )
+            if get_bool_env_var("CK_MOE"):
+                assert not no_combine, f"{no_combine=} is not supported."
+                if self.block_quant:
+                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    assert (
+                        activation == "silu"
+                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return ck_moe_2stages(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        activation=(
+                            ActivationType.Silu
+                            if activation == "silu"
+                            else ActivationType.Gelu
+                        ),
+                    )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace and not no_combine,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv
+                if self.block_quant
+                else layer.w2_weight_scale
+            ),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+            no_combine=no_combine,
+        )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
```
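Two of the fp8.py hunks above drop per-call `try/except ImportError` guards around the marlin helpers. The replacement pattern, visible in the first hunk, binds both helpers to one stub when vllm is missing, so the error surfaces only if the marlin path is actually taken. A condensed sketch of that pattern follows; the symbol names match the hunk, but the block is a reconstruction (the `MARLIN_FP8_AVAILABLE = True` placement inside `try` is assumed, since that part is elided in the page).

```python
# Sketch of the "stub at import time, fail at call time" pattern from the hunk above.
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True  # assumed: set in the elided part of the try block
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def dummy_func(*args, **kwargs):
        raise ImportError(
            "marlin FP8 requires some operators from vllm. Please install vllm."
        )

    # Both entry points share the stub, so call sites stay unconditional and the
    # ImportError only appears if the marlin path is exercised without vllm.
    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
```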
python/sglang/srt/layers/quantization/fp8_kernel.py (view file @ 177320a5)

```diff
@@ -34,15 +34,23 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
+
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
+
+_enable_jit_deepgemm = False
 if _is_cuda:
     import deep_gemm
-    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
     sm_version = get_device_sm()
     if sm_version == 90 and get_bool_env_var(
@@ -53,6 +61,7 @@ if _is_cuda:
 logger = logging.getLogger(__name__)
 
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        fp8_max = 224.0
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
     return x_q, x_s
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
@@ -368,7 +358,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
@@ -386,15 +375,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        fp8_max = 224.0
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
     x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
 
     num_head, num_seq, head_size = x.shape
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
         head_size,
         x.stride(0),
         x.stride(1),
-        -fp8_max,
+        fp8_min,
         fp8_max,
         BLOCK_SIZE,
     )
 
     return x_q, x_s
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
```
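The docstring of the new `scaled_fp8_quant` above spells out three modes. A usage sketch based on that docstring (not part of the commit; it assumes a CUDA build, since the function dispatches to sgl_kernel ops):

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Dynamic per-tensor scaling: one scale for the whole tensor.
q, s = scaled_fp8_quant(x)
assert q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) and s.numel() == 1

# Dynamic per-token scaling: one scale per row of the 2D input.
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert s_tok.shape == (x.shape[0], 1)

# Static scaling with a precomputed scalar scale; the first dimension of the
# output can additionally be padded up to num_token_padding.
scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
q_pad, _ = scaled_fp8_quant(x, scale=scale, num_token_padding=32)
assert q_pad.shape[0] == 32
```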
python/sglang/srt/layers/quantization/fp8_utils.py (view file @ 177320a5)

```diff
 import os
 from typing import List, Optional, Tuple
 
 import torch
 
+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
     is_hip,
 )
 
-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
 
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_block)
+        scaled_fp8_quant(x_dq_block)
         if _is_cuda
         else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
     )
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_channel)
+        scaled_fp8_quant(x_dq_channel)
         if _is_cuda
         else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
    )
@@ -264,7 +262,7 @@ def apply_fp8_linear(
         # final solution should be: 1. add support to per-tensor activation scaling.
         # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
         if _is_hip and weight_scale.numel() == 1:
-            qinput, x_scale = ops.scaled_fp8_quant(
+            qinput, x_scale = vllm_ops.scaled_fp8_quant(
                 input_2d,
                 input_scale,
                 use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -275,32 +273,29 @@ def apply_fp8_linear(
             )
 
         if cutlass_fp8_supported:
-            try:
-                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                    # Fall back to vllm cutlass w8a8 fp8 kernel
-                    output = ops.cutlass_scaled_mm(
-                        qinput,
-                        weight,
-                        out_dtype=input.dtype,
-                        scale_a=x_scale,
-                        scale_b=weight_scale,
-                        bias=bias,
-                    )
-                else:
-                    assert (
-                        weight_scale.numel() == weight.shape[1]
-                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-                    output = fp8_scaled_mm(
-                        qinput,
-                        weight,
-                        x_scale,
-                        weight_scale,
-                        out_dtype=input.dtype,
-                        bias=bias,
-                    )
-                return output.view(*output_shape)
-            except (ImportError, NameError, AttributeError):
-                pass
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = vllm_ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
 
         # torch.scaled_mm supports per tensor weights + activations only
         # so fallback to naive if per channel or per token
@@ -343,8 +338,10 @@ def apply_fp8_linear(
         # Making sure the dummy tensor is on the same device as the weight
         global TORCH_DEVICE_IDENTITY
-        if TORCH_DEVICE_IDENTITY.device != weight.device:
-            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+        if TORCH_DEVICE_IDENTITY is None:
+            TORCH_DEVICE_IDENTITY = torch.ones(
+                1, dtype=torch.float32, device=weight.device
+            )
 
         # GEMM
         # This computes C = (X * W).
@@ -372,13 +369,6 @@ def apply_fp8_linear(
         return output.to(dtype=input.dtype).view(*output_shape)
 
 
-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
 # TODO(luka): follow similar pattern for marlin and block-fp8-linear
 # https://github.com/vllm-project/vllm/issues/14397
@@ -405,9 +395,7 @@ class Fp8LinearOp:
         # We also don't pad when using torch.compile,
         # as it breaks with dynamic shapes.
         if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
+            enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
             pad_output = not enable_torch_compile
         self.output_padding = 17 if pad_output else None
@@ -439,13 +427,13 @@ class Fp8LinearOp:
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
         if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
             if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+                qinput, x_scale = scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
             else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                qinput, x_scale = vllm_ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     scale_ub=input_scale_ub,
@@ -455,7 +443,7 @@ class Fp8LinearOp:
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                 # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
+                output = vllm_ops.cutlass_scaled_mm(
                     qinput,
                     weight,
                     out_dtype=input.dtype,
@@ -482,14 +470,14 @@ class Fp8LinearOp:
         else:
             # Maybe apply padding to output, see comment in __init__
             if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+                qinput, x_scale = scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     num_token_padding=self.output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
             else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                qinput, x_scale = vllm_ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     num_token_padding=self.output_padding,
@@ -562,9 +550,12 @@ class Fp8LinearOp:
             # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
             # Making sure the dummy tensor is on the same device as the weight
             global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+            if TORCH_DEVICE_IDENTITY is None:
+                TORCH_DEVICE_IDENTITY = torch.ones(
+                    1, dtype=torch.float32, device=weight.device
+                )
 
             output = torch._scaled_mm(
                 qinput,
```
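One side effect of the fp8_utils.py changes above is that `TORCH_DEVICE_IDENTITY` is no longer a `torch.ones(1)` allocated at import time: it starts as `None` and is created on the weight's device the first time the `torch._scaled_mm` path needs it, which is also why `maybe_create_device_identity()` could be deleted. A stripped-down sketch of that lazy-initialization pattern; the helper function is illustrative, since the real code inlines the check inside `apply_fp8_linear` and `Fp8LinearOp.apply`.

```python
import torch

# Dummy ones tensor for torch._scaled_mm, which requires explicit input scales
# from PyTorch 2.5 on. Start as None so that importing the module allocates nothing.
TORCH_DEVICE_IDENTITY = None


def _device_identity(device: torch.device) -> torch.Tensor:
    """Create the dummy scale on first use, on the device of the weight that needs it."""
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=device)
    return TORCH_DEVICE_IDENTITY
```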
python/sglang/srt/layers/quantization/utils.py (view file @ 177320a5)

```diff
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
 
 def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        if _is_cuda:
-            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
-        else:
-            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
-                weight_dq, max_w_scale
-            )
+        weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
         start = end
 
     return max_w_scale, weight
```
python/sglang/srt/layers/quantization/w8a8_int8.py (view file @ 177320a5)

```diff
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
 
 
 class W8A8Int8Config(QuantizationConfig):
```
python/sglang/srt/layers/rotary_embedding.py (view file @ 177320a5)

```diff
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
 
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
```
python/sglang/srt/lora/backend/__init__.py (deleted, 100644 → 0; view file @ d7bc19a4)

```diff
-from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-
-
-def get_backend_from_name(name: str) -> BaseLoRABackend:
-    """
-    Get corresponding backend class from backend's name
-    """
-    if name == "triton":
-        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
-
-        return TritonLoRABackend
-    elif name == "flashinfer":
-        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
-
-        return FlashInferLoRABackend
-    else:
-        raise ValueError(f"Invalid backend: {name}")
-
-
-__all__ = [
-    "BaseLoRABackend",
-    "FlashInferLoRABackend",
-    "TritonLoRABackend",
-    "get_backend_from_name",
-]
```
python/sglang/srt/lora/backend/base_backend.py (view file @ 177320a5)

```diff
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
@@ -115,3 +115,19 @@ class BaseLoRABackend:
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
```
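Since `get_backend_from_name` now lives next to `BaseLoRABackend` (the package `__init__.py` above was deleted), callers resolve a backend class as sketched below. The construction line is commented out because the backend constructors are not shown in this diff.

```python
from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend
# backend = backend_cls(...)                   # constructor args depend on the backend

try:
    get_backend_from_name("unknown")
except ValueError as err:
    print(err)  # Invalid backend: unknown
```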
python/sglang/srt/lora/backend/flashinfer_backend.py (view file @ 177320a5)

```diff
@@ -2,7 +2,7 @@ from typing import Tuple
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
```
python/sglang/srt/lora/backend/triton_backend.py (view file @ 177320a5)

```diff
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
     gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
```
python/sglang/srt/lora/layers.py (view file @ 177320a5)

```diff
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
```
python/sglang/srt/lora/lora.py (view file @ 177320a5)

```diff
@@ -27,7 +27,7 @@ from torch import nn
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
```
python/sglang/srt/lora/lora_manager.py (view file @ 177320a5)

```diff
@@ -22,7 +22,7 @@ import torch
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
```
python/sglang/srt/managers/detokenizer_manager.py (view file @ 177320a5)

One import line is dropped in this hunk (-14,7 +14,6); the inline view does not preserve which of the lines below it was.

```diff
@@ -14,7 +14,6 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import dataclasses
 import json
 import logging
 import os
 import signal
```
python/sglang/srt/managers/mm_utils.py (view file @ 177320a5)

```diff
 """
 Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-from sglang.utils import logger
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
```
python/sglang/srt/managers/multimodal_processor.py (view file @ 177320a5)

```diff
@@ -5,8 +5,6 @@ import logging
 import pkgutil
 from functools import lru_cache
 
-from transformers import PROCESSOR_MAPPING
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
```
python/sglang/srt/managers/multimodal_processors/base_processor.py (view file @ 177320a5)

```diff
@@ -8,8 +8,6 @@ from typing import List, Optional
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
```
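The base_processor.py change moves the decord import from module scope into the frame-estimation method, with the in-code comment explaining why: decord wheels are unavailable on some arm platforms, so merely importing the processor should not require the package. The same lazy-import pattern in isolation; the function name and return value here are illustrative, not sglang API.

```python
def count_video_frames(video_path: str) -> int:
    # Lazy import: the module itself can be imported on platforms without decord;
    # only actually reading a video pulls in the dependency.
    from decord import VideoReader, cpu

    reader = VideoReader(video_path, ctx=cpu(0))
    return len(reader)
```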
python/sglang/srt/model_executor/cuda_graph_runner.py (view file @ 177320a5)

```diff
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import get_available_gpu_memory, is_hip
 
+_is_hip = is_hip()
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-_is_hip = is_hip()
-
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
```
python/sglang/srt/model_executor/model_runner.py (view file @ 177320a5)

One line is dropped in this hunk (-320,7 +320,6); the inline view does not preserve which of the lines below it was.

```diff
@@ -320,7 +320,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
         if not self.use_mla_backend:
             logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
```