Unverified Commit 177320a5 authored by Lianmin Zheng, committed by GitHub

Clean up imports (#5467)

parent d7bc19a4
......@@ -24,6 +24,7 @@ from sglang.api import (
user_end,
video,
)
from sglang.global_config import global_config
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
greedy_token_selection,
......@@ -31,6 +32,7 @@ from sglang.lang.choices import (
unconditional_likelihood_normalized,
)
from sglang.utils import LazyImport
from sglang.version import __version__
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
......@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
# Other configs
from sglang.global_config import global_config
from sglang.version import __version__
__all__ = [
"Engine",
"Runtime",
......
......@@ -707,10 +707,6 @@ def sample_random_requests(
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
print(
"If you do not want to randomly sample from a dataset,"
" please use --dataset-name random-ids."
)
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
......
from typing import List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
......
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
......
......@@ -2,7 +2,7 @@ import dataclasses
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union
import numpy as np
......
import os
import warnings
from typing import Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
......
......@@ -5,13 +5,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
def compile_func(function, backend):
......
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
......@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(name, source=self.last_node)
self.variables[name] = new_node
new_node = SglVariable(expr.name, source=self.last_node)
self.variables[expr.name] = new_node
def get_var(self, name):
ret = self.arguments.get(name, None)
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
import os
from typing import List, Tuple
import torch
import torch.library
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
......
......@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
return self.forward_hip
else:
return self.forward_native
if _is_cuda:
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 (8-bit floating point) format.
Args:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
num_token_padding (Optional[int]): If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
- False: compute single scale per tensor
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- quantized_tensor: The FP8 quantized version of input
- scale_tensor: The scaling factors used for quantization
Raises:
AssertionError: If input is not 2D or if static scale's numel != 1
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:
# Dynamic scaling
if use_per_token_if_dynamic:
scale = torch.empty(
(shape[0], 1), device=input.device, dtype=torch.float32
)
sgl_per_token_quant_fp8(input, output, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=False
) # False for dynamic
else:
# Static scaling
assert (
scale.numel() == 1
), f"Expected scalar scale, got numel={scale.numel()}"
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=True
) # True for static
return output, scale
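
For reference, a minimal usage sketch of the API documented above, assuming scaled_fp8_quant is imported from sglang.srt.layers.quantization.fp8_kernel (as the other hunks in this commit do) and a CUDA build with sgl_kernel is available:

import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)  # input must be 2D

# Dynamic per-tensor scaling: a single float32 scale for the whole tensor.
q, s = scaled_fp8_quant(x)

# Dynamic per-token scaling: one scale per row, shape (4, 1).
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static scaling with a pre-computed scalar scale.
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
q_static, _ = scaled_fp8_quant(x, scale=scale)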
......@@ -19,11 +19,10 @@ import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
......
......@@ -21,13 +21,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.srt.utils import is_cuda_available
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
divide,
......@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_cuda_available, set_weight_attrs
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
logger = logging.getLogger(__name__)
......
......@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available
_is_cuda = is_cuda_available()
......@@ -31,7 +32,6 @@ if _is_cuda:
rmsnorm,
)
from sglang.srt.custom_op import CustomOp
logger = logging.getLogger(__name__)
......
......@@ -2,6 +2,7 @@ import logging
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module
try:
from deep_gemm import (
......@@ -13,8 +14,6 @@ try:
except ImportError:
use_deep_gemm = False
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
......@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
if _is_hip:
from vllm._custom_ops import scaled_fp8_quant
logger = logging.getLogger(__name__)
_is_hip = is_hip()
_buffer = None
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
......@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
)
for expert in range(layer.num_experts_per_partition):
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
......
......@@ -13,6 +13,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
......@@ -22,28 +23,25 @@ from sglang.srt.utils import (
)
_is_hip = is_hip()
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
@triton.jit
def write_zeros_to_output(
c_ptr,
......@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size
# activations apply per-token quantization when weights apply per-channel quantization by default
if _is_cuda:
A, A_scale = sgl_scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else:
A, A_scale = vllm_ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
A, A_scale = scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
......
......@@ -13,7 +13,6 @@
# ==============================================================================
import math
import os
from typing import Callable, Optional
import torch
......@@ -29,6 +28,10 @@ _is_hip = is_hip()
if _is_cuda:
from sgl_kernel import moe_fused_gate
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
expert_distribution_recorder = ExpertDistributionRecorder()
......@@ -59,11 +62,6 @@ def fused_topk(
topk: int,
renormalize: bool,
):
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as vllm_ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
......@@ -76,20 +74,12 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if _is_cuda or _is_hip:
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
else:
vllm_ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
del token_expert_indicies
if renormalize:
......
......@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_rank
__all__ = [
"BasevLLMParameter",
"PackedvLLMParameter",
......
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import logging
......@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format,
should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
logger = logging.getLogger(__name__)
......
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import enum
import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
......@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
if not _is_cuda:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
try:
import vllm
......@@ -58,8 +51,6 @@ __all__ = [
class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
return super().__new__(cls)
......@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
)
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
......@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
......@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
......@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False,
)
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
layer.w13_g_idx_sort_indices,
layer.w13_weight_packed.shape[1] * self.packed_factor,
......@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.num_bits,
)
replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
......@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use fused_marlin_moe, please install vllm"
)
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for " "fused Marlin MoE method."
......