change / sglang · Commits · 177320a5
"docs/advanced_features/hyperparameter_tuning.md" did not exist on "cdcbde5fc3155edaa6b98a13ab8764101e657b23"
Commit 177320a5 (unverified), authored Apr 16, 2025 by Lianmin Zheng, committed by GitHub on Apr 16, 2025.

Clean up imports (#5467)

Parent: d7bc19a4

Changes: 51 files in the full commit. This page shows 20 changed files with 288 additions and 306 deletions (+288 −306).
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py  +0 −3
python/sglang/srt/layers/quantization/fp8.py  +113 −132
python/sglang/srt/layers/quantization/fp8_kernel.py  +76 −45
python/sglang/srt/layers/quantization/fp8_utils.py  +54 −63
python/sglang/srt/layers/quantization/utils.py  +5 −11
python/sglang/srt/layers/quantization/w8a8_int8.py  +5 −7
python/sglang/srt/layers/rotary_embedding.py  +3 −2
python/sglang/srt/lora/backend/__init__.py  +0 −25
python/sglang/srt/lora/backend/base_backend.py  +18 −2
python/sglang/srt/lora/backend/flashinfer_backend.py  +1 −1
python/sglang/srt/lora/backend/triton_backend.py  +1 −1
python/sglang/srt/lora/layers.py  +1 −1
python/sglang/srt/lora/lora.py  +1 −1
python/sglang/srt/lora/lora_manager.py  +1 −1
python/sglang/srt/managers/detokenizer_manager.py  +0 −1
python/sglang/srt/managers/mm_utils.py  +4 −3
python/sglang/srt/managers/multimodal_processor.py  +0 −2
python/sglang/srt/managers/multimodal_processors/base_processor.py  +3 −2
python/sglang/srt/model_executor/cuda_graph_runner.py  +2 −2
python/sglang/srt/model_executor/model_runner.py  +0 −1
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

```diff
@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     Fp8LinearOp,
-    maybe_create_device_identity,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
```
python/sglang/srt/layers/quantization/fp8.py

```diff
@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )
 
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
     permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
 if _is_hip:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
         )
         layer.logical_widths = output_partition_sizes
-
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.weight_scale_inv.data, requires_grad=False
             )
             return
+
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.use_marlin:
-            try:
-                prepare_fp8_layer_for_marlin(layer)
-                # Activations not quantized for marlin.
-                del layer.input_scale
-            except ImportError:
-                self.use_marlin = False
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
     def apply(
         self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-            try:
-                return apply_fp8_marlin_linear(
-                    input=x,
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
-                    workspace=layer.workspace,
-                    size_n=layer.output_size_per_partition,
-                    size_k=layer.input_size_per_partition,
-                    bias=bias,
-                )
-            except ImportError:
-                self.use_marlin = False
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
 
         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
 
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
 
             for expert in range(layer.num_experts):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
                     layer.w13_weight[expert_id][start : start + shard_size, :],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
-                if _is_cuda:
-                    (
-                        layer.w13_weight[expert_id][start : start + shard_size, :],
-                        _,
-                    ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                else:
-                    (
-                        layer.w13_weight[expert_id][start : start + shard_size, :],
-                        _,
-                    ) = vllm_ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id]
-                    )
+                (
+                    layer.w13_weight[expert_id][start : start + shard_size, :],
+                    _,
+                ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                 start += shard_size
 
         layer.w13_weight_scale = torch.nn.Parameter(
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
                 correction_bias=correction_bias,
             )
 
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
-            assert not no_combine, f"{no_combine=} is not supported."
-            return ck_moe_2stages_win4(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-
-        if _is_hip and get_bool_env_var("CK_MOE"):
-            assert not no_combine, f"{no_combine=} is not supported."
-            if self.block_quant:
-                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return ck_moe_2stages(
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+                assert not no_combine, f"{no_combine=} is not supported."
+                return ck_moe_2stages_win4(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
                         else ActivationType.Gelu
                     ),
                 )
-        else:
-            # Expert fusion with FP8 quantization
-            return fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size,
-                no_combine=no_combine,
-            )
+            if get_bool_env_var("CK_MOE"):
+                assert not no_combine, f"{no_combine=} is not supported."
+                if self.block_quant:
+                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    assert (
+                        activation == "silu"
+                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return ck_moe_2stages(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        activation=(
+                            ActivationType.Silu
+                            if activation == "silu"
+                            else ActivationType.Gelu
+                        ),
+                    )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace and not no_combine,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv
+                if self.block_quant
+                else layer.w2_weight_scale
+            ),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+            no_combine=no_combine,
+        )
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
```
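The marlin fallback above collapses the two per-function stubs into a single `dummy_func` bound to both names, so a missing vllm surfaces as a clear error only when a marlin code path is actually called, not at import time. A small self-contained sketch of the same pattern (the `fastlib` dependency and function names below are hypothetical, not from the commit):

```python
# Generic sketch of the optional-dependency stub pattern used in the hunk above.
try:
    from fastlib import fused_op_a, fused_op_b  # hypothetical optional dependency

    FASTLIB_AVAILABLE = True
except ImportError:
    FASTLIB_AVAILABLE = False

    def _missing_fastlib(*args, **kwargs):
        raise ImportError(
            "this code path requires fastlib; install it to use the fused kernels"
        )

    # Both entry points share one stub, so importing this module always succeeds
    # and the ImportError is raised only if a fused kernel is actually invoked.
    fused_op_a = fused_op_b = _missing_fastlib
```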
python/sglang/srt/layers/quantization/fp8_kernel.py

```diff
@@ -34,15 +34,23 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
 _is_cuda = is_cuda()
+
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
 
+_enable_jit_deepgemm = False
 if _is_cuda:
     import deep_gemm
-    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
     sm_version = get_device_sm()
     if sm_version == 90 and get_bool_env_var(
@@ -53,6 +61,7 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
+
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
     return x_q, x_s
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
@@ -368,7 +358,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
@@ -386,15 +375,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
 
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
     x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
 
     num_head, num_seq, head_size = x.shape
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
         head_size,
         x.stride(0),
         x.stride(1),
-        -fp8_max,
+        fp8_min,
         fp8_max,
         BLOCK_SIZE,
     )
     return x_q, x_s
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty(
+                (shape[0], 1), device=input.device, dtype=torch.float32
+            )
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(input, output, scale, is_static=False)  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(input, output, scale, is_static=True)  # True for static
+
+    return output, scale
```
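The new `scaled_fp8_quant` wrapper added above covers three quantization modes. A minimal usage sketch (not part of the commit itself), assuming a CUDA build of sglang with `sgl_kernel` available; shapes and values are illustrative:

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

# 2D activations; scaled_fp8_quant asserts input.ndim == 2.
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")

# Dynamic per-tensor quantization: a single fp32 scale is computed.
x_fp8, scale = scaled_fp8_quant(x)

# Dynamic per-token quantization: one scale per row of the input.
x_fp8_tok, row_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static quantization with a precomputed scalar scale (numel() must be 1).
static_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
x_fp8_static, _ = scaled_fp8_quant(x, scale=static_scale)
```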
python/sglang/srt/layers/quantization/fp8_utils.py

```diff
-import os
 from typing import List, Optional, Tuple
 
 import torch
 
+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
     is_hip,
 )
 
-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
 
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_block)
+        scaled_fp8_quant(x_dq_block)
         if _is_cuda
         else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
     )
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_channel)
+        scaled_fp8_quant(x_dq_channel)
         if _is_cuda
         else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
     )
@@ -264,7 +262,7 @@ def apply_fp8_linear(
         # final solution should be: 1. add support to per-tensor activation scaling.
         # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
         if _is_hip and weight_scale.numel() == 1:
-            qinput, x_scale = ops.scaled_fp8_quant(
+            qinput, x_scale = vllm_ops.scaled_fp8_quant(
                 input_2d,
                 input_scale,
                 use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -275,32 +273,29 @@ def apply_fp8_linear(
             )
 
         if cutlass_fp8_supported:
-            try:
-                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                    # Fall back to vllm cutlass w8a8 fp8 kernel
-                    output = ops.cutlass_scaled_mm(
-                        qinput,
-                        weight,
-                        out_dtype=input.dtype,
-                        scale_a=x_scale,
-                        scale_b=weight_scale,
-                        bias=bias,
-                    )
-                else:
-                    assert (
-                        weight_scale.numel() == weight.shape[1]
-                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-                    output = fp8_scaled_mm(
-                        qinput,
-                        weight,
-                        x_scale,
-                        weight_scale,
-                        out_dtype=input.dtype,
-                        bias=bias,
-                    )
-                return output.view(*output_shape)
-            except (ImportError, NameError, AttributeError):
-                pass
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = vllm_ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -343,8 +338,10 @@ def apply_fp8_linear(
     # Making sure the dummy tensor is on the same device as the weight
     global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY.device != weight.device:
-        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(
+            1, dtype=torch.float32, device=weight.device
+        )
 
     # GEMM
     # This computes C = (X * W).
@@ -372,13 +369,6 @@ def apply_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
 # TODO(luka): follow similar pattern for marlin and block-fp8-linear
 #  https://github.com/vllm-project/vllm/issues/14397
@@ -405,9 +395,7 @@ class Fp8LinearOp:
         # We also don't pad when using torch.compile,
         # as it breaks with dynamic shapes.
         if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
+            enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
             pad_output = not enable_torch_compile
         self.output_padding = 17 if pad_output else None
@@ -439,13 +427,13 @@ class Fp8LinearOp:
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
         if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
             if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+                qinput, x_scale = scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
             else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                qinput, x_scale = vllm_ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     scale_ub=input_scale_ub,
@@ -455,7 +443,7 @@ class Fp8LinearOp:
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                 # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
+                output = vllm_ops.cutlass_scaled_mm(
                     qinput,
                     weight,
                     out_dtype=input.dtype,
@@ -482,14 +470,14 @@ class Fp8LinearOp:
         else:
             # Maybe apply padding to output, see comment in __init__
             if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+                qinput, x_scale = scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     num_token_padding=self.output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
             else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                qinput, x_scale = vllm_ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     num_token_padding=self.output_padding,
@@ -562,9 +550,12 @@ class Fp8LinearOp:
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
         # Making sure the dummy tensor is on the same device as the weight
         global TORCH_DEVICE_IDENTITY
-        if TORCH_DEVICE_IDENTITY.device != weight.device:
-            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+        if TORCH_DEVICE_IDENTITY is None:
+            TORCH_DEVICE_IDENTITY = torch.ones(
+                1, dtype=torch.float32, device=weight.device
+            )
 
         output = torch._scaled_mm(
             qinput,
```
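The `TORCH_DEVICE_IDENTITY` change above replaces an eagerly allocated dummy scale tensor (plus the `maybe_create_device_identity()` helper) with lazy creation on first use, directly on the weight's device. A standalone sketch of that pattern (the helper name below is illustrative, not from the diff); the diff's own comment explains why the tensor exists at all: `torch._scaled_mm` no longer accepts a missing input scale from PyTorch 2.5 on:

```python
import torch

_DEVICE_IDENTITY = None  # created lazily, so importing the module allocates nothing


def _device_identity(device: torch.device) -> torch.Tensor:
    """Return a cached ones(1) fp32 tensor on `device`, usable as a dummy scale."""
    global _DEVICE_IDENTITY
    if _DEVICE_IDENTITY is None or _DEVICE_IDENTITY.device != device:
        _DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=device)
    return _DEVICE_IDENTITY
```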
python/sglang/srt/layers/quantization/utils.py

```diff
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
 
 def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        if _is_cuda:
-            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
-        else:
-            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
-                weight_dq, max_w_scale
-            )
+        weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
         start = end
 
     return max_w_scale, weight
```
python/sglang/srt/layers/quantization/w8a8_int8.py

```diff
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
 
 
 class W8A8Int8Config(QuantizationConfig):
```
python/sglang/srt/layers/rotary_embedding.py

```diff
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
+
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
```
python/sglang/srt/lora/backend/__init__.py (deleted, file mode 100644 → 0)

```diff
-from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-
-
-def get_backend_from_name(name: str) -> BaseLoRABackend:
-    """
-    Get corresponding backend class from backend's name
-    """
-    if name == "triton":
-        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
-
-        return TritonLoRABackend
-    elif name == "flashinfer":
-        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
-
-        return FlashInferLoRABackend
-    else:
-        raise ValueError(f"Invalid backend: {name}")
-
-
-__all__ = [
-    "BaseLoRABackend",
-    "FlashInferLoRABackend",
-    "TritonLoRABackend",
-    "get_backend_from_name",
-]
```
python/sglang/srt/lora/backend/base_backend.py

```diff
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
@@ -115,3 +115,19 @@ class BaseLoRABackend:
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
```
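With `get_backend_from_name` now living in base_backend.py (the old `sglang.srt.lora.backend` package `__init__.py` is deleted above), callers resolve a backend class via the new import path. A short sketch mirroring the function shown in this diff:

```python
from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend
# "flashinfer" returns FlashInferLoRABackend; any other name raises
# ValueError(f"Invalid backend: {name}"), as in the function above.
```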
python/sglang/srt/lora/backend/flashinfer_backend.py

```diff
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
```
python/sglang/srt/lora/backend/triton_backend.py

```diff
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
     gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
```
python/sglang/srt/lora/layers.py

```diff
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
```
python/sglang/srt/lora/lora.py

```diff
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
```
python/sglang/srt/lora/lora_manager.py

```diff
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
```
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -14,7 +14,6 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import dataclasses
-import json
 import logging
 import os
 import signal
```
python/sglang/srt/managers/mm_utils.py

```diff
 """
 Multi-modality utils
 """
+
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-from sglang.utils import logger
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
```
python/sglang/srt/managers/multimodal_processor.py

```diff
@@ -5,8 +5,6 @@ import logging
 import pkgutil
 from functools import lru_cache
 
-from transformers import PROCESSOR_MAPPING
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
```
python/sglang/srt/managers/multimodal_processors/base_processor.py

```diff
@@ -8,8 +8,6 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
```
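The hunk above defers the decord import into the method that actually needs it, so the processor module can still be imported on platforms without a decord wheel. A self-contained sketch of the same lazy-import idea (the function below is illustrative, not sglang API):

```python
def count_video_frames(video_path: str) -> int:
    # Deferred import: decord is only required if a video is actually processed.
    from decord import VideoReader, cpu

    reader = VideoReader(video_path, ctx=cpu(0))
    return len(reader)
```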
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import get_available_gpu_memory, is_hip
 
+_is_hip = is_hip()
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-_is_hip = is_hip()
-
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -320,7 +320,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
         if not self.use_mla_backend:
-            logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
```