Unverified Commit 6af03f23 authored by BadrBasowid's avatar BadrBasowid Committed by GitHub
Browse files

[Refactor] [1/N] Reorganize kernel abstraction directory (#34055)


Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
parent 1a6cf39d
...@@ -26,24 +26,16 @@ from vllm.config import ( ...@@ -26,24 +26,16 @@ from vllm.config import (
PassConfig, PassConfig,
VllmConfig, VllmConfig,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.kernels.linear import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( ChannelWiseTorchFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel, FlashInferFP8ScaledMMLinearKernel,
) FP8ScaledMMLinearKernel,
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.layernorm import RMSNorm
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey, QuantKey,
......
...@@ -26,22 +26,14 @@ from vllm.config import ( ...@@ -26,22 +26,14 @@ from vllm.config import (
VllmConfig, VllmConfig,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.kernels.linear import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel, FlashInferFP8ScaledMMLinearKernel,
) FP8ScaledMMLinearKernel,
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
PerTensorTorchFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.activation import SiluAndMul
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
......
...@@ -10,16 +10,10 @@ from abc import ABC ...@@ -10,16 +10,10 @@ from abc import ABC
import pytest import pytest
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.kernels.linear import (
Int8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel,
) Int8ScaledMMLinearLayerConfig,
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearKernel,
) )
......
...@@ -42,11 +42,9 @@ from vllm.distributed import ( ...@@ -42,11 +42,9 @@ from vllm.distributed import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module re-exports linear kernel implementations to provide a
stable import interface during an ongoing reorganization. Upcoming
PRs will remove the scaled_mm and mixed_precision subdirectories
and reorganize kernels by provider (aiter, cutlass, flashinfer, etc.)
rather than by precision type. By centralizing exports here, we
minimize the need to update imports across other modules when the
internal structure changes. If you are adding a new kernel selector
or kernel implementation, add it to this __init__.py to maintain
import stability.
"""
import os import os
from typing import TypeVar from typing import TypeVar
import torch import torch
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( from vllm.model_executor.kernels.linear.mixed_precision import (
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel, CutlassInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel, FlashInferFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel, ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel, RowWiseTorchFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.kernels.linear.scaled_mm.triton import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import ( from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
XPUFP8ScaledMMLinearKernel, XPUFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
...@@ -80,6 +124,29 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = ...@@ -80,6 +124,29 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
], ],
} }
# Candidate mixed-precision (MPLinearKernel) implementations for each
# platform, in priority/performance order (when available).
# choose_mp_linear_kernel() walks the list for the current platform from
# front to back and returns the first kernel that is not listed in
# envs.VLLM_DISABLED_KERNELS, whose minimum capability is satisfied, and
# whose can_implement() accepts the config -- so earlier entries win.
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
CutlassW4A8LinearKernel,
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
ConchLinearKernel,
ExllamaLinearKernel,
],
PlatformEnum.ROCM: [
ConchLinearKernel,
ExllamaLinearKernel,
],
PlatformEnum.XPU: [
XPUwNa16LinearKernel,
],
PlatformEnum.CPU: [
Dynamic4bitLinearKernel,
CPUWNA16LinearKernel,
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) _KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
...@@ -234,3 +301,97 @@ def init_int8_linear_kernel( ...@@ -234,3 +301,97 @@ def init_int8_linear_kernel(
"azp_adj", "azp_adj",
], ],
) )
def choose_mp_linear_kernel(
    config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Candidates come from `_POSSIBLE_KERNELS` for the current platform and are
    tried in list order; the first one that is not disabled, satisfies the
    capability requirement and reports that it can implement `config` is
    returned.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
            implemented.
        compute_capability (Optional[int], optional): The compute capability of
            the target device encoded as `major * 10 + minor`. If None, it is
            derived from `current_platform`; if the platform reports no
            device capability the capability check is skipped.
            Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config (including
            the case where the current platform has no registered kernels),
            or if `current_platform` is None while the compute capability
            must be inferred.

    Returns:
        type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons: list[str] = []

    # Platforms without an entry in the table used to surface as a bare
    # KeyError; report them through the documented ValueError instead.
    platform_enum = current_platform._enum
    candidates = _POSSIBLE_KERNELS.get(platform_enum, [])
    if not candidates:
        failure_reasons.append(
            f" no mixed-precision kernels are registered for platform "
            f"{platform_enum}"
        )

    for kernel in candidates:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel.__name__} disabled by environment variable"
            )
            continue

        # Only enforce the capability floor when a capability is known.
        if compute_capability is not None:
            min_capability = kernel.get_min_capability()
            if min_capability > compute_capability:
                failure_reasons.append(
                    f" {kernel.__name__} requires capability "
                    f"{min_capability}, current compute "
                    f"capability is {compute_capability}"
                )
                continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        failure_reasons.append(
            f" {kernel.__name__} cannot implement due to: {failure_reason}"
        )

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )
# Public re-export surface of this package. Every kernel selector, base
# class, config and kernel implementation imported by this module should be
# listed here so downstream code can import it from this package.
__all__ = [
    # Kernel selectors
    "init_fp8_linear_kernel",
    "init_int8_linear_kernel",
    "choose_mp_linear_kernel",
    # Scaled-MM base classes and layer configs
    "FP8ScaledMMLinearKernel",
    "Int8ScaledMMLinearKernel",
    "ScaledMMLinearKernel",
    "FP8ScaledMMLinearLayerConfig",
    "Int8ScaledMMLinearLayerConfig",
    "ScaledMMLinearLayerConfig",
    # Scaled-MM kernel implementations
    "AiterInt8ScaledMMLinearKernel",
    "CPUInt8ScaledMMLinearKernel",
    "CutlassFP8ScaledMMLinearKernel",
    "CutlassInt8ScaledMMLinearKernel",
    "FlashInferFP8ScaledMMLinearKernel",
    "ChannelWiseTorchFP8ScaledMMLinearKernel",
    "PerTensorTorchFP8ScaledMMLinearKernel",
    "RowWiseTorchFP8ScaledMMLinearKernel",
    "ROCmFP8ScaledMMLinearKernel",
    "TritonInt8ScaledMMLinearKernel",
    # Imported by this module but previously missing from the export list.
    "XPUFP8ScaledMMLinearKernel",
    # Mixed-precision base class and layer config
    "MPLinearKernel",
    "MPLinearLayerConfig",
    # Mixed-precision kernel implementations
    "AllSparkLinearKernel",
    "ConchLinearKernel",
    "CPUWNA16LinearKernel",
    "CutlassW4A8LinearKernel",
    "Dynamic4bitLinearKernel",
    "ExllamaLinearKernel",
    "MacheteLinearKernel",
    "MarlinLinearKernel",
    "XPUwNa16LinearKernel",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUwNa16LinearKernel,
)
# Names re-exported by the mixed_precision package: the MPLinearKernel base
# class and its layer config, followed by the kernel implementations
# imported by this module.
__all__ = [
"MPLinearKernel",
"MPLinearLayerConfig",
"AllSparkLinearKernel",
"ConchLinearKernel",
"CPUWNA16LinearKernel",
"CutlassW4A8LinearKernel",
"Dynamic4bitLinearKernel",
"ExllamaLinearKernel",
"MacheteLinearKernel",
"MarlinLinearKernel",
"XPUwNa16LinearKernel",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
)
# Names re-exported by the scaled_mm package: the ScaledMM base classes and
# layer configs first, followed by the kernel implementations imported by
# this module.
__all__ = [
"FP8ScaledMMLinearKernel",
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearKernel",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearKernel",
"ScaledMMLinearLayerConfig",
"AiterInt8ScaledMMLinearKernel",
"CPUInt8ScaledMMLinearKernel",
"CutlassFP8ScaledMMLinearKernel",
"CutlassInt8ScaledMMLinearKernel",
"FlashInferFP8ScaledMMLinearKernel",
"ChannelWiseTorchFP8ScaledMMLinearKernel",
"PerTensorTorchFP8ScaledMMLinearKernel",
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment