Unverified Commit 177320a5 authored by Lianmin Zheng, committed by GitHub

Clean up imports (#5467)

parent d7bc19a4
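The recurring pattern in this cleanup is to hoist platform-dependent imports to module scope and to route FP8 quantization through a single `scaled_fp8_quant` exposed by `sglang.srt.layers.quantization.fp8_kernel`, so call sites no longer branch on `_is_cuda` themselves. A minimal sketch of the import layout most of the touched files converge on (module and symbol names are taken from the hunks below; the snippet itself is illustrative, not the contents of any one file):

```python
# Illustrative only: the shape of the cleaned-up imports, not an exact file.
from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
    # CUDA builds take fused ops from sgl_kernel at module import time.
    from sgl_kernel import silu_and_mul  # noqa: F401
else:
    # Other builds fall back to vLLM's custom ops.
    from vllm import _custom_ops as vllm_ops  # noqa: F401

# Call sites import one symbol; the backend dispatch lives behind it.
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant  # noqa: F401
```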
@@ -24,6 +24,7 @@ from sglang.api import (
     user_end,
     video,
 )
+from sglang.global_config import global_config
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.choices import (
     greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
     unconditional_likelihood_normalized,
 )
 from sglang.utils import LazyImport
+from sglang.version import __version__
 
 ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
 VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
 
-# Other configs
-from sglang.global_config import global_config
-from sglang.version import __version__
-
 __all__ = [
     "Engine",
     "Runtime",
......
@@ -707,10 +707,6 @@ def sample_random_requests(
     # Download sharegpt if necessary
     if not os.path.isfile(dataset_path):
-        print(
-            "If you do not want to randomly sample from a dataset,"
-            " please use --dataset-name random-ids."
-        )
         dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
......
-from typing import List, Optional, Union
-import numpy as np
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
......
-from typing import Callable, List, Optional, Union
+from typing import List, Optional, Union
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
......
@@ -2,7 +2,7 @@ import dataclasses
 import logging
 import time
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
......
 import os
 import warnings
-from typing import Optional
 
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
......
@@ -5,13 +5,7 @@ from typing import List, Union
 from sglang.global_config import global_config
 from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
-from sglang.lang.ir import (
-    SglArgument,
-    SglConstantText,
-    SglExpr,
-    SglSamplingParams,
-    SglVariable,
-)
+from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
 
 
 def compile_func(function, backend):
......
"""Tracing a program.""" """Tracing a program."""
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import ( from sglang.lang.ir import (
SglArgument, SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText, SglConstantText,
SglExpr, SglExpr,
SglExprList, SglExprList,
SglFork, SglFork,
SglFunction,
SglGen, SglGen,
SglGetForkItem, SglGetForkItem,
SglRoleBegin, SglRoleBegin,
...@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState): ...@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
self.cur_role = None self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd): def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(name, source=self.last_node) new_node = SglVariable(expr.name, source=self.last_node)
self.variables[name] = new_node self.variables[expr.name] = new_node
def get_var(self, name): def get_var(self, name):
ret = self.arguments.get(name, None) ret = self.arguments.get(name, None)
......
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import logging
-import os
 from typing import List, Tuple
 
 import torch
-import torch.library
 
 from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
......
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
             return self.forward_hip
         else:
             return self.forward_native
-
-
-if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    def scaled_fp8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None,
-        num_token_padding: Optional[int] = None,
-        use_per_token_if_dynamic: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Quantize input tensor to FP8 (8-bit floating point) format.
-
-        Args:
-            input (torch.Tensor): Input tensor to be quantized
-            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-                If None, scales will be computed dynamically.
-            num_token_padding (Optional[int]): If specified, pad the first dimension
-                of the output to at least this value.
-            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-                determines the quantization granularity:
-                - True: compute scale per token
-                - False: compute single scale per tensor
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-                - quantized_tensor: The FP8 quantized version of input
-                - scale_tensor: The scaling factors used for quantization
-
-        Raises:
-            AssertionError: If input is not 2D or if static scale's numel != 1
-        """
-        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-        shape = input.shape
-        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-        if num_token_padding:
-            shape = (max(num_token_padding, input.shape[0]), shape[1])
-        output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-        if scale is None:
-            # Dynamic scaling
-            if use_per_token_if_dynamic:
-                scale = torch.empty(
-                    (shape[0], 1), device=input.device, dtype=torch.float32
-                )
-                sgl_per_token_quant_fp8(input, output, scale)
-            else:
-                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-                sgl_per_tensor_quant_fp8(
-                    input, output, scale, is_static=False
-                )  # False for dynamic
-        else:
-            # Static scaling
-            assert (
-                scale.numel() == 1
-            ), f"Expected scalar scale, got numel={scale.numel()}"
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=True
-            )  # True for static
-
-        return output, scale
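For reference, a small usage sketch of the helper removed above (the other hunks in this commit import it from `sglang.srt.layers.quantization.fp8_kernel` instead). The tensor shapes and the static scale value are made up for illustration, and a CUDA build with `sgl_kernel` is assumed:

```python
import torch

# Assumed import path after this commit, per the new imports elsewhere in the diff.
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

# Dynamic per-token scaling: one float32 scale per row, shape (4, 1).
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Dynamic per-tensor scaling: a single float32 scale of shape (1,).
q_all, s_all = scaled_fp8_quant(x)

# Static scaling: reuse a precomputed scalar scale.
static_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
q_static, s_static = scaled_fp8_quant(x, scale=static_scale)
```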
@@ -19,11 +19,10 @@ import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor
 
+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.server import Engine
-from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
......
@@ -21,13 +21,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sglang.srt.utils import is_cuda_available
-
-_is_cuda = is_cuda_available()
-
-if _is_cuda:
-    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 logger = logging.getLogger(__name__)
......
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
         rmsnorm,
     )
 
-from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
......
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch.nn import Module
 
 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False
 
-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
 
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+_is_hip = is_hip()
+
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant
 
 logger = logging.getLogger(__name__)
 
-_is_hip = is_hip()
-
-_buffer = None
-
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )
 
             for expert in range(layer.num_experts_per_partition):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
 
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
......
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
-
-logger = logging.getLogger(__name__)
-
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
+logger = logging.getLogger(__name__)
+
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
 
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
             # activations apply per-token quantization when weights apply per-channel quantization by default
-            if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
......
@@ -13,7 +13,6 @@
 # ==============================================================================
 
 import math
-import os
 from typing import Callable, Optional
 
 import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
+if _is_cuda or _is_hip:
+    from sgl_kernel import topk_softmax
+
 
 expert_distribution_recorder = ExpertDistributionRecorder()
@@ -59,11 +62,6 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda or _is_hip:
-        from sgl_kernel import topk_softmax
-    else:
-        from vllm import _custom_ops as vllm_ops
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    if _is_cuda or _is_hip:
-        topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
-    else:
-        vllm_ops.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
+    topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
 
     del token_expert_indicies
 
     if renormalize:
......
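For context, `topk_softmax` (from `sgl_kernel`, now called unconditionally above) fuses the router softmax with top-k expert selection. A plain PyTorch sketch of the semantics the fused kernel is assumed to implement, with made-up sizes:

```python
import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # gating_output: (num_tokens, num_experts) router logits.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, k=topk, dim=-1)
    return topk_weights, topk_ids

# Example: 4 tokens routed over 8 experts, top-2 per token.
logits = torch.randn(4, 8)
weights, ids = topk_softmax_reference(logits, topk=2)
# fused_topk then optionally renormalizes the weights to sum to 1 per token.
```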
@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
......
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 
 logger = logging.getLogger(__name__)
......
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.fused_moe_triton import (
-        FusedMoE,
-        FusedMoEMethodBase,
-        FusedMoeWeightScaleSupported,
-    )
-
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
+if not _is_cuda:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -58,8 +51,6 @@ __all__ = [
 class CompressedTensorsMoEMethod:
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
         return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
-        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             self.num_bits,
         )
         replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if not VLLM_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed, to use fused_marlin_moe, please install vllm"
-            )
 
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for " "fused Marlin MoE method."
......