Unverified Commit 177320a5 authored by Lianmin Zheng, committed by GitHub

Clean up imports (#5467)

parent d7bc19a4
@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     Fp8LinearOp,
-    maybe_create_device_identity,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
...
@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False

-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )

-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
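The hunk above replaces two per-function stubs with a single `dummy_func` that both marlin entry points are bound to. A minimal standalone sketch of that fallback pattern, with hypothetical names (`fast_op`, `prepare_fast_op`, `some_optional_lib`) standing in for the real vllm helpers:

```python
# Sketch only: the names below are made up; they mirror the pattern, not the API.
try:
    from some_optional_lib import fast_op, prepare_fast_op

    FAST_OP_AVAILABLE = True
except ImportError:
    FAST_OP_AVAILABLE = False

    def _missing_dependency(*args, **kwargs):
        raise ImportError("fast_op requires some_optional_lib; please install it.")

    # One stub serves every missing entry point; the error is raised at call
    # time, so importing this module still succeeds without the dependency.
    fast_op = prepare_fast_op = _missing_dependency
```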
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
-    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )

-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()

 if _is_hip:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight

-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = logging.getLogger(__name__)
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
         )
         layer.logical_widths = output_partition_sizes
-        layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.weight_scale_inv.data, requires_grad=False
             )
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
             )

         if self.use_marlin:
-            try:
             prepare_fp8_layer_for_marlin(layer)
             # Activations not quantized for marlin.
             del layer.input_scale
-            except ImportError:
-                self.use_marlin = False

     def apply(
         self,
@@ -406,7 +404,6 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
         if self.use_marlin:
-            try:
             return apply_fp8_marlin_linear(
                 input=x,
                 weight=layer.weight,
@@ -416,8 +413,6 @@ class Fp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias,
             )
-            except ImportError:
-                self.use_marlin = False

         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
@@ -706,19 +701,11 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-                if _is_cuda:
                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                 )
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
                     (
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         _,
-                    ) = vllm_ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id]
-                    )
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                     start += shard_size

             layer.w13_weight_scale = torch.nn.Parameter(
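The loop above implements requantization to a shared scale: each shard is dequantized with its own scale and then requantized with the maximum scale across shards, so a single per-tensor scale can be stored for the fused weight. A self-contained numerical sketch of that idea using plain torch ops (values and shapes are made up; the real code operates on fp8 checkpoints via `scaled_fp8_quant`):

```python
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0

shards = [torch.randn(4, 8), torch.randn(4, 8) * 3]
scales = [s.abs().max() / fp8_max for s in shards]  # per-shard scales
max_scale = max(scales)                             # shared (max) scale

requantized = []
for shard in shards:
    # In the real code the shard is first dequantized with its own scale;
    # here we start from full-precision values, so that step is implicit.
    q = torch.clamp(shard / max_scale, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    requantized.append(q)
# All shards now share one scale (max_scale).
```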
@@ -930,7 +909,8 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
                 # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
                 return ck_moe_2stages_win4(
@@ -942,10 +922,13 @@ class Fp8MoEMethod:
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
-                        ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
                     ),
                 )

-        if _is_hip and get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("CK_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
                     # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
@@ -978,7 +961,7 @@ class Fp8MoEMethod:
                         else ActivationType.Gelu
                     ),
                 )
-        else:
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
@@ -996,9 +979,7 @@ class Fp8MoEMethod:
                 else layer.w13_weight_scale
             ),
             w2_scale=(
-                layer.w2_weight_scale_inv
-                if self.block_quant
-                else layer.w2_weight_scale
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
             ),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
...
@@ -34,15 +34,23 @@ from sglang.srt.utils import (
     supports_custom_op,
 )

-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()

+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
+
+_enable_jit_deepgemm = False
 if _is_cuda:
     import deep_gemm
-    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )

     sm_version = get_device_sm()
     if sm_version == 90 and get_bool_env_var(
@@ -53,6 +61,7 @@ if _is_cuda:

 logger = logging.getLogger(__name__)

 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(
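The constants above hoist the fp8 range computation to module level. A quick sanity check of the non-HIP values; the 224.0 cap used on HIP is taken from this diff, not from `torch.finfo`:

```python
import torch

# Non-HIP path: the quantization range comes straight from the dtype.
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max   # 448.0
fp8_min = -fp8_max                     # -448.0
print(fp8_dtype, fp8_min, fp8_max)
```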
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )

     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

     return x_q, x_s
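For reference, a hedged usage sketch of the group-wise quantization helper after this change (the output dtype now comes from the module-level `_fp8_type`; the CUDA device, dtype, and shapes below are assumptions for illustration):

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

# Assumed setup: CUDA device, bf16 activations, last dim divisible by group_size.
x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)

print(x_q.dtype)   # fp8 (e4m3fn on NVIDIA, e4m3fnuz on HIP)
print(x_s.shape)   # torch.Size([4, 2]) -- one float32 scale per 128-element group
```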
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
@@ -368,7 +358,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
@@ -386,15 +375,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
     x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

     num_head, num_seq, head_size = x.shape
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
         head_size,
         x.stride(0),
         x.stride(1),
-        -fp8_max,
+        fp8_min,
         fp8_max,
         BLOCK_SIZE,
     )
     return x_q, x_s
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
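An illustrative usage sketch of the new sglang-native `scaled_fp8_quant` added above (it dispatches to sgl_kernel ops, so a CUDA device is assumed; shapes, dtype, and the static scale value are made up for illustration):

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

# Dynamic per-tensor scale
q, s = scaled_fp8_quant(x)                                          # s.numel() == 1

# Dynamic per-token scales
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)   # s_tok: (16, 1)

# Static quantization with a precomputed scalar scale
static_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
q_static, _ = scaled_fp8_quant(x, scale=static_scale)
```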
-import os
 from typing import List, Optional, Tuple

 import torch

+try:
+    from vllm import _custom_ops as vllm_ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
     is_hip,
 )

-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()

 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale

-_is_cuda = is_cuda()
-
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None

 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_block)
+        scaled_fp8_quant(x_dq_block)
         if _is_cuda
         else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
     )
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_channel)
+        scaled_fp8_quant(x_dq_channel)
         if _is_cuda
         else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
     )
@@ -264,7 +262,7 @@ def apply_fp8_linear(
     # final solution should be: 1. add support to per-tensor activation scaling.
     # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
     if _is_hip and weight_scale.numel() == 1:
-        qinput, x_scale = ops.scaled_fp8_quant(
+        qinput, x_scale = vllm_ops.scaled_fp8_quant(
             input_2d,
             input_scale,
             use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -275,10 +273,9 @@ def apply_fp8_linear(
         )

     if cutlass_fp8_supported:
-        try:
         if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
             # Fall back to vllm cutlass w8a8 fp8 kernel
-            output = ops.cutlass_scaled_mm(
+            output = vllm_ops.cutlass_scaled_mm(
                 qinput,
                 weight,
                 out_dtype=input.dtype,
@@ -299,8 +296,6 @@ def apply_fp8_linear(
                 bias=bias,
             )
         return output.view(*output_shape)
-        except (ImportError, NameError, AttributeError):
-            pass

     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -343,8 +338,10 @@ def apply_fp8_linear(
         # Making sure the dummy tensor is on the same device as the weight
         global TORCH_DEVICE_IDENTITY
-        if TORCH_DEVICE_IDENTITY.device != weight.device:
-            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+        if TORCH_DEVICE_IDENTITY is None:
+            TORCH_DEVICE_IDENTITY = torch.ones(
+                1, dtype=torch.float32, device=weight.device
+            )

         # GEMM
         # This computes C = (X * W).
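The change above (and the matching one in `Fp8LinearOp` further down) switches `TORCH_DEVICE_IDENTITY` from an eagerly created CPU tensor to one created lazily on the weight's device, which removes the need for `maybe_create_device_identity()`. A minimal sketch of that lazy-singleton pattern, with a hypothetical helper name:

```python
import torch

# Hypothetical stand-in for the module-level TORCH_DEVICE_IDENTITY handling.
_IDENTITY = None


def get_device_identity(device: torch.device) -> torch.Tensor:
    """Allocate the dummy ones tensor on first use, directly on the right device."""
    global _IDENTITY
    if _IDENTITY is None:
        _IDENTITY = torch.ones(1, dtype=torch.float32, device=device)
    return _IDENTITY
```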
@@ -372,13 +369,6 @@ def apply_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)

-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
 # TODO(luka): follow similar pattern for marlin and block-fp8-linear
 # https://github.com/vllm-project/vllm/issues/14397
@@ -405,9 +395,7 @@ class Fp8LinearOp:
         # We also don't pad when using torch.compile,
         # as it breaks with dynamic shapes.
         if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
+            enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
             pad_output = not enable_torch_compile
         self.output_padding = 17 if pad_output else None
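The replacement above leans on sglang's `get_bool_env_var` helper instead of parsing the environment inline. A rough sketch of what such a helper typically does; the exact set of accepted values lives in `sglang.srt.utils` and is an assumption here:

```python
import os


def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in for sglang.srt.utils.get_bool_env_var.
    value = os.getenv(name, default).strip().lower()
    return value in ("1", "true", "yes")


# e.g. get_bool_env_var_sketch("SGLANG_ENABLE_TORCH_COMPILE") -> False unless the
# variable is set to a truthy value.
```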
@@ -439,13 +427,13 @@ class Fp8LinearOp:
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
         if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
             if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+                qinput, x_scale = scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
             else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                qinput, x_scale = vllm_ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     scale_ub=input_scale_ub,
@@ -455,7 +443,7 @@ class Fp8LinearOp:
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                 # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
+                output = vllm_ops.cutlass_scaled_mm(
                     qinput,
                     weight,
                     out_dtype=input.dtype,
@@ -482,14 +470,14 @@ class Fp8LinearOp:
         else:
             # Maybe apply padding to output, see comment in __init__
             if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+                qinput, x_scale = scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     num_token_padding=self.output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
             else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                qinput, x_scale = vllm_ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
                     num_token_padding=self.output_padding,
@@ -562,9 +550,12 @@ class Fp8LinearOp:
             # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place

+            # Making sure the dummy tensor is on the same device as the weight
             global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+            if TORCH_DEVICE_IDENTITY is None:
+                TORCH_DEVICE_IDENTITY = torch.ones(
+                    1, dtype=torch.float32, device=weight.device
+                )

             output = torch._scaled_mm(
                 qinput,
...
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union

 import torch

+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant


 def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        if _is_cuda:
-            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
-        else:
-            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
-                weight_dq, max_w_scale
-            )
+        weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
         start = end

     return max_w_scale, weight
...
 from typing import Any, Callable, Dict, List, Optional

 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm


 class W8A8Int8Config(QuantizationConfig):
...
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available

 _is_cuda_available = is_cuda_available()

 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
...
-from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-
-
-def get_backend_from_name(name: str) -> BaseLoRABackend:
-    """
-    Get corresponding backend class from backend's name
-    """
-    if name == "triton":
-        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
-
-        return TritonLoRABackend
-    elif name == "flashinfer":
-        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
-
-        return FlashInferLoRABackend
-    else:
-        raise ValueError(f"Invalid backend: {name}")
-
-
-__all__ = [
-    "BaseLoRABackend",
-    "FlashInferLoRABackend",
-    "TritonLoRABackend",
-    "get_backend_from_name",
-]
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
@@ -115,3 +115,19 @@ class BaseLoRABackend:
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
@@ -2,7 +2,7 @@ from typing import Tuple

 import torch

-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
...
 import torch

-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
     gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
...
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend


 class BaseLayerWithLoRA(nn.Module):
...
@@ -27,7 +27,7 @@ from torch import nn

 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
...
@@ -22,7 +22,7 @@ import torch

 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
...
@@ -14,7 +14,6 @@
 """DetokenizerManager is a process that detokenizes the token ids."""

 import dataclasses
-import json
 import logging
 import os
 import signal
...
""" """
Multi-modality utils Multi-modality utils
""" """
import logging
from abc import abstractmethod from abc import abstractmethod
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
...@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict, global_server_args_dict,
logger,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once from sglang.srt.utils import print_warning_once
from sglang.utils import logger
logger = logging.getLogger(__name__)
class MultiModalityDataPaddingPattern: class MultiModalityDataPaddingPattern:
......
@@ -5,8 +5,6 @@ import logging
 import pkgutil
 from functools import lru_cache

-from transformers import PROCESSOR_MAPPING
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
...
@@ -8,8 +8,6 @@ from typing import List, Optional
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
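The hunk above moves the decord import into the method body so that merely importing the processor module does not require decord. A small self-contained sketch of the same deferred-import idea; the function name is hypothetical:

```python
def count_video_frames(path: str) -> int:
    # Deferred import: decord is only needed if a video is actually opened,
    # so platforms without a decord wheel can still import this module.
    from decord import VideoReader, cpu

    vr = VideoReader(path, ctx=cpu(0))
    return len(vr)
```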
...
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import get_available_gpu_memory, is_hip

+_is_hip = is_hip()
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

-_is_hip = is_hip()
-

 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
...
@@ -320,7 +320,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

         if not self.use_mla_backend:
-            logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
...