Unverified Commit 177320a5 authored by Lianmin Zheng, committed by GitHub

Clean up imports (#5467)

parent d7bc19a4
......@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
)
from sglang.srt.layers.quantization.fp8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
......@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable,
**kwargs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
......
......@@ -8,15 +8,6 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
try:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
......@@ -27,11 +18,12 @@ try:
except ImportError:
MARLIN_FP8_AVAILABLE = False
def apply_fp8_marlin_linear(*args, **kwargs):
raise ImportError("vllm is not installed")
def dummy_func(*args, **kwargs):
raise ImportError(
"marlin FP8 requires some operators from vllm. Please install vllm."
)
def prepare_fp8_layer_for_marlin(*args, **kwargs):
raise ImportError("vllm is not installed")
apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
from sglang.srt.distributed import get_tensor_model_parallel_world_size
......@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
......@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_hip,
permute_weight,
print_warning_once,
set_weight_attrs,
)
ACTIVATION_SCHEMES = ["static", "dynamic"]
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
from aiter.ops.shuffle import shuffle_weight
_is_cuda = is_cuda()
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
......@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
......@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale_inv.data, requires_grad=False
)
return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If the checkpoint is not serialized in fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
if self.cutlass_fp8_supported or self.use_marlin:
......@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
)
if self.use_marlin:
try:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
except ImportError:
self.use_marlin = False
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
def apply(
self,
......@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor:
if self.use_marlin:
try:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
except ImportError:
self.use_marlin = False
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
if self.block_quant:
return apply_w8a8_block_fp8_linear(
......@@ -516,7 +511,7 @@ class Fp8MoEMethod:
)
# WEIGHTS
if get_bool_env_var("USE_INT4_WEIGHT"):
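# Note: the INT4-packed MoE weights go through ROCm-only aiter kernels, hence the additional _is_hip gate below.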
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
......@@ -617,7 +612,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
......@@ -649,7 +644,7 @@ class Fp8MoEMethod:
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
self.process_weights_hip_int4(layer)
return
......@@ -706,20 +701,12 @@ class Fp8MoEMethod:
requires_grad=False,
)
for expert in range(layer.num_experts):
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
......@@ -796,18 +783,10 @@ class Fp8MoEMethod:
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
......@@ -930,41 +909,11 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
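# ROCm dispatch order below: INT4-packed aiter kernels first, then the CK fused-MoE kernels when CK_MOE is set; otherwise fall through to the generic fused_experts path.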
if _is_hip:
if get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -978,33 +927,65 @@ class Fp8MoEMethod:
else ActivationType.Gelu
),
)
else:
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)
if get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):
......
......@@ -34,15 +34,23 @@ from sglang.srt.utils import (
supports_custom_op,
)
_enable_jit_deepgemm = False
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
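# On ROCm the e4m3fnuz format is used and its usable range is clamped to +/-224.0 rather than torch.finfo(_fp8_type).max.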
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max
_enable_jit_deepgemm = False
if _is_cuda:
import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
from sgl_kernel import (
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)
sm_version = get_device_sm()
if sm_version == 90 and get_bool_env_var(
......@@ -53,6 +61,7 @@ if _is_cuda:
logger = logging.getLogger(__name__)
if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt(
......@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum value to avoid division by zero.
dtype: The dype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
......@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
M = x.numel() // group_size
N = group_size
if column_major_scales:
......@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
):
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s
......@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = fp8_type_,
dtype: torch.dtype = _fp8_type,
):
assert x.is_contiguous(), "`x` is not contiguous"
......@@ -368,7 +358,6 @@ def static_quant_fp8(
x: torch.Tensor,
x_s: torch.Tensor,
repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`.
......@@ -386,15 +375,8 @@ def static_quant_fp8(
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
......@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
def per_tensor_quant_mla_fp8(
x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
x: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with tensor-wise quantization
......@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
x_q = x.new_empty(x.size(), dtype=dtype)
x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
num_head, num_seq, head_size = x.shape
......@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
head_size,
x.stride(0),
x.stride(1),
-fp8_max,
fp8_min,
fp8_max,
BLOCK_SIZE,
)
return x_q, x_s
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 (8-bit floating point) format.
Args:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
num_token_padding (Optional[int]): If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
- False: compute single scale per tensor
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- quantized_tensor: The FP8 quantized version of input
- scale_tensor: The scaling factors used for quantization
Raises:
AssertionError: If input is not 2D or if static scale's numel != 1
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:
# Dynamic scaling
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
sgl_per_token_quant_fp8(input, output, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=False
) # False for dynamic
else:
# Static scaling
assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=True
) # True for static
return output, scale
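For reference, a minimal usage sketch of the new `scaled_fp8_quant` helper (hypothetical shapes and dtype, assuming a CUDA build where the `sgl_per_token_quant_fp8` / `sgl_per_tensor_quant_fp8` ops above are importable):

import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

# Illustrative 2D activation tensor; shapes and dtype are placeholders.
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Dynamic per-tensor scaling: one float32 scale for the whole tensor.
q, s = scaled_fp8_quant(x)

# Dynamic per-token scaling: one scale per row of `x`.
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static scaling with a precomputed scalar scale.
static_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
q_stat, _ = scaled_fp8_quant(x, scale=static_scale)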
import os
from typing import List, Optional, Tuple
import torch
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm,
per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
......@@ -17,30 +25,20 @@ from sglang.srt.utils import (
is_hip,
)
try:
import vllm
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
TORCH_DEVICE_IDENTITY = None
_TORCH_VERSION = torch.__version__.split("+")[0]
try:
......@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_block)
scaled_fp8_quant(x_dq_block)
if _is_cuda
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
)
......@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
) -> Tuple[torch.Tensor, torch.Tensor]:
x_dq_channel = x_q_channel.to(torch.float32) * x_s
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_channel)
scaled_fp8_quant(x_dq_channel)
if _is_cuda
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
)
......@@ -264,7 +262,7 @@ def apply_fp8_linear(
# final solution should be: 1. add support to per-tensor activation scaling.
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
if _is_hip and weight_scale.numel() == 1:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
......@@ -275,32 +273,29 @@ def apply_fp8_linear(
)
if cutlass_fp8_supported:
try:
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)
except (ImportError, NameError, AttributeError):
pass
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = vllm_ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
......@@ -343,8 +338,10 @@ def apply_fp8_linear(
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
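# The dummy per-tensor scale is now allocated lazily, directly on the weight's device.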
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)
# GEMM
# This computes C = (X * W).
......@@ -372,13 +369,6 @@ def apply_fp8_linear(
return output.to(dtype=input.dtype).view(*output_shape)
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
......@@ -405,9 +395,7 @@ class Fp8LinearOp:
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
if pad_output is None:
enable_torch_compile = os.environ.get(
"SGLANG_ENABLE_TORCH_COMPILE", "0"
).lower() in ("1", "true", "yes")
enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
pad_output = not enable_torch_compile
self.output_padding = 17 if pad_output else None
......@@ -439,13 +427,13 @@ class Fp8LinearOp:
# for sgl-kernel fp8_scaled_mm, it supports per-channel W now
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
qinput, x_scale = scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
scale_ub=input_scale_ub,
......@@ -455,7 +443,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
output = vllm_ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
......@@ -482,14 +470,14 @@ class Fp8LinearOp:
else:
# Maybe apply padding to output, see comment in __init__
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
qinput, x_scale = scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
......@@ -562,9 +550,12 @@ class Fp8LinearOp:
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)
output = torch._scaled_mm(
qinput,
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
from typing import List, Mapping, Tuple, Union
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
def is_fp8_fnuz() -> bool:
......@@ -116,12 +115,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
if _is_cuda:
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
else:
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
weight_dq, max_w_scale
)
weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
start = end
return max_w_scale, weight
......
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
......@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm
class W8A8Int8Config(QuantizationConfig):
......
......@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available
_is_cuda_available = is_cuda_available()
if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
else:
from vllm import _custom_ops as ops
from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
)
else:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
vllm_rotary_embedding(
positions,
query,
key,
......
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
def get_backend_from_name(name: str) -> BaseLoRABackend:
"""
Get corresponding backend class from backend's name
"""
if name == "triton":
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
return TritonLoRABackend
elif name == "flashinfer":
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
return FlashInferLoRABackend
else:
raise ValueError(f"Invalid backend: {name}")
__all__ = [
"BaseLoRABackend",
"FlashInferLoRABackend",
"TritonLoRABackend",
"get_backend_from_name",
]
......@@ -75,7 +75,7 @@ class BaseLoRABackend:
qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
......@@ -98,7 +98,7 @@ class BaseLoRABackend:
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
......@@ -115,3 +115,19 @@ class BaseLoRABackend:
def set_batch_info(self, batch_info: LoRABatchInfo):
self.batch_info = batch_info
def get_backend_from_name(name: str) -> BaseLoRABackend:
"""
Get the corresponding backend class from the backend's name.
"""
if name == "triton":
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
return TritonLoRABackend
elif name == "flashinfer":
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
return FlashInferLoRABackend
else:
raise ValueError(f"Invalid backend: {name}")
......@@ -2,7 +2,7 @@ from typing import Tuple
import torch
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available
......
import torch
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
gate_up_lora_b_fwd,
qkv_lora_b_fwd,
......
......@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
class BaseLayerWithLoRA(nn.Module):
......
......@@ -27,7 +27,7 @@ from torch import nn
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
......
......@@ -22,7 +22,7 @@ import torch
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
......
......@@ -14,7 +14,6 @@
"""DetokenizerManager is a process that detokenizes the token ids."""
import dataclasses
import json
import logging
import os
import signal
......
"""
Multi-modality utils
"""
import logging
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
......@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
logger,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once
from sglang.utils import logger
logger = logging.getLogger(__name__)
class MultiModalityDataPaddingPattern:
......
......@@ -5,8 +5,6 @@ import logging
import pkgutil
from functools import lru_cache
from transformers import PROCESSOR_MAPPING
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
......
......@@ -8,8 +8,6 @@ from typing import List, Optional
import numpy as np
import PIL
from decord import VideoReader, cpu
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality
......@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
"""
estimate the total frame count from all visual input
"""
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
# Before processing inputs
estimated_frames_list = []
for image in image_data:
......
......@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
_is_hip = is_hip()
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
......
......@@ -320,7 +320,6 @@ class ModelRunner:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
......