Commit 6af03f23 (unverified), authored by BadrBasowid, committed by GitHub
Browse files

[Refactor] [1/N] Reorganize kernel abstraction directory (#34055)


Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
parent 1a6cf39d
...@@ -5,7 +5,7 @@ from collections.abc import Sequence ...@@ -5,7 +5,7 @@ from collections.abc import Sequence
import torch import torch
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.kernels.linear import ( # noqa: E501
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
) )
......
...@@ -7,13 +7,13 @@ import torch ...@@ -7,13 +7,13 @@ import torch
from compressed_tensors.quantization import ActivationOrdering from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.kernels.linear import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, MPLinearLayerConfig,
choose_mp_linear_kernel, choose_mp_linear_kernel,
) )
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks, marlin_repeat_scales_on_all_ranks,
) )
......
...@@ -6,13 +6,13 @@ from collections.abc import Callable ...@@ -6,13 +6,13 @@ from collections.abc import Callable
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.kernels.linear import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, MPLinearLayerConfig,
choose_mp_linear_kernel, choose_mp_linear_kernel,
) )
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
......
...@@ -9,12 +9,12 @@ from torch.nn import Parameter ...@@ -9,12 +9,12 @@ from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
create_fp8_input_scale, create_fp8_input_scale,
......
...@@ -7,12 +7,12 @@ import torch ...@@ -7,12 +7,12 @@ import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_int8_linear_kernel,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_int8_linear_kernel,
)
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
......
...@@ -7,15 +7,13 @@ import torch ...@@ -7,15 +7,13 @@ import torch
from compressed_tensors.quantization import ActivationOrdering from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.kernels.linear import (
CompressedTensorsScheme, MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, MPLinearLayerConfig,
choose_mp_linear_kernel, choose_mp_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
MarlinLinearKernel, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype, get_marlin_input_dtype,
......
...@@ -8,6 +8,9 @@ from torch.nn import Module ...@@ -8,6 +8,9 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
...@@ -18,9 +21,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -18,9 +21,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin, prepare_fp8_layer_for_marlin,
......
...@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops ...@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
...@@ -46,9 +49,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -46,9 +49,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
......
...@@ -10,6 +10,10 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE ...@@ -10,6 +10,10 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -27,10 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -27,10 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_dynamic_override, get_dynamic_override,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
ConchLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cpu import ( # noqa: E501
CPUWNA16LinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
CutlassW4A8LinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
Dynamic4bitLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
MacheteLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501
XPUwNa16LinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
# Maps each platform to the mixed-precision linear kernel classes that may be
# used on it. choose_mp_linear_kernel walks the current platform's list from
# front to back and returns the first kernel that is not disabled, satisfies
# the compute-capability requirement and reports that it can implement the
# layer config -- so the position of a kernel in its list is its priority.
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
    PlatformEnum.CUDA: [
        CutlassW4A8LinearKernel,
        MacheteLinearKernel,
        AllSparkLinearKernel,
        MarlinLinearKernel,
        ConchLinearKernel,
        ExllamaLinearKernel,
    ],
    PlatformEnum.ROCM: [
        ConchLinearKernel,
        ExllamaLinearKernel,
    ],
    PlatformEnum.XPU: [
        XPUwNa16LinearKernel,
    ],
    PlatformEnum.CPU: [
        Dynamic4bitLinearKernel,
        CPUWNA16LinearKernel,
    ],
}
def choose_mp_linear_kernel(
    config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Candidates come from `_POSSIBLE_KERNELS` for the current platform and are
    tried in list order; the first one that is not disabled through
    `VLLM_DISABLED_KERNELS`, whose minimum capability is satisfied and whose
    `can_implement(config)` succeeds is returned.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
            implemented.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get
            the compute capability. Defaults to None.

    Raises:
        ValueError: If the compute capability has to be looked up but there is
            no current platform, if the current platform has no registered
            kernels, or if no kernel can implement the given config.

    Returns:
        type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        # The platform may report no capability at all; in that case the
        # minimum-capability filter below is skipped.
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    # Platforms without an entry used to surface as a bare KeyError here;
    # report them with the same exception type as every other failure.
    platform_kernels = _POSSIBLE_KERNELS.get(current_platform._enum)
    if platform_kernels is None:
        raise ValueError(
            "Failed to find a kernel that can implement the "
            "WNA16 linear layer. Reasons: \n"
            f" no kernels are registered for platform {current_platform._enum}"
        )

    failure_reasons = []
    for kernel in platform_kernels:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel.__name__} disabled by environment variable"
            )
            continue

        if (
            compute_capability is not None
            and kernel.get_min_capability() > compute_capability
        ):
            failure_reasons.append(
                f" {kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute "
                f"capability is {compute_capability}"
            )
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel

        failure_reasons.append(
            f" {kernel.__name__} cannot implement due to: {failure_reason}"
        )

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )
...@@ -9,6 +9,9 @@ from torch.nn.parameter import Parameter ...@@ -9,6 +9,9 @@ from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -45,9 +48,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -45,9 +48,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
......
...@@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter ...@@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -17,9 +20,6 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -17,9 +20,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
......
...@@ -6,7 +6,7 @@ from collections.abc import Callable ...@@ -6,7 +6,7 @@ from collections.abc import Callable
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.kernels.linear import (
init_int8_linear_kernel, init_int8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment