Unverified commit 1ac66942, authored by zhangyiming, committed by GitHub
Browse files

[OOT] Add OOT support for linear kernel. (#37989)


Signed-off-by: menogrey <1299267905@qq.com>
parent 6cc7abdc
......@@ -7,15 +7,21 @@ Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
import inspect
from abc import ABC
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.kernels.linear import (
AiterInt8ScaledMMLinearKernel,
CPUInt8ScaledMMLinearKernel,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
init_int8_linear_kernel,
register_linear_kernel,
)
from vllm.platforms import PlatformEnum
pytestmark = pytest.mark.cpu_test
......@@ -85,3 +91,39 @@ def test_cpu_kernel_accepts_all_configs():
assert can_impl, (
f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
)
class OOTInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    """Stub int8 scaled-mm kernel standing in for an out-of-tree (OOT) backend.

    Every hook is a no-op: the capability checks always succeed and the
    weight/compute methods do nothing. The class exists so that
    ``test_register_oot_linear_kernel`` has a distinct kernel type to register
    and to look for in the selection result.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report the kernel as supported, whatever ``compute_capability`` is.

        Returns:
            ``(True, None)`` -- supported, with no failure reason.
        """
        return (True, None)

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config ``c`` without inspecting it.

        Returns:
            ``(True, None)`` -- implementable, with no failure reason.
        """
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Do nothing; the stub has no weights to post-process."""
        return None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Placeholder compute hook.

        NOTE: this stub performs no computation and returns ``None`` despite
        the ``torch.Tensor`` annotation; the visible test never calls it.
        """
        return None
@patch("vllm.model_executor.kernels.linear.current_platform")
def test_register_oot_linear_kernel(platform_mock):
    """Check that a kernel registered for the OOT platform is the one selected.

    The current platform is mocked to report ``PlatformEnum.OOT``; the stub
    kernel is registered as an "int8" kernel for that platform, and
    ``init_int8_linear_kernel`` must then return an instance of it.
    """
    platform_mock._enum = PlatformEnum.OOT

    # register_linear_kernel appends to a module-level registry. Scope that
    # mutation with patch.dict so the stub kernel does not leak into other
    # tests running in the same process. Seeding a fresh list for the OOT key
    # also shields this test from any OOT kernels registered earlier, and
    # prevents in-place mutation of a pre-existing list that patch.dict
    # (a shallow snapshot) could not undo.
    with patch.dict(
        "vllm.model_executor.kernels.linear._POSSIBLE_INT8_KERNELS",
        {PlatformEnum.OOT: []},
    ):
        register_linear_kernel(OOTInt8ScaledMMLinearKernel, PlatformEnum.OOT, "int8")
        kernel = init_int8_linear_kernel(True, True, True, "module")

    assert isinstance(kernel, OOTInt8ScaledMMLinearKernel), (
        "init_int8_linear_kernel should return an instance of the registered kernel"
    )
......@@ -367,10 +367,44 @@ def choose_mp_linear_kernel(
)
def register_linear_kernel(
    kernel_class: type,
    platform: PlatformEnum,
    kernel_type: str = "mp",
) -> None:
    """
    Add a linear kernel class to the candidates used during kernel selection.

    The class is appended to the end of the per-platform list of the registry
    matching ``kernel_type``; the list is created if the platform has no
    entry yet.

    Args:
        kernel_class (type): The kernel class to register.
        platform (PlatformEnum): The platform the kernel applies to.
        kernel_type (str): Which registry to extend: "mp", "int8", or "fp8".
            Defaults to "mp".

    Raises:
        ValueError: If the kernel_type is not recognized.
    """
    # One registry per kernel family, keyed by the public kernel_type string.
    registries = {
        "mp": _POSSIBLE_KERNELS,
        "int8": _POSSIBLE_INT8_KERNELS,
        "fp8": _POSSIBLE_FP8_KERNELS,
    }

    target = registries.get(kernel_type)
    if target is None:
        raise ValueError(f"Unrecognized kernel type: {kernel_type}")

    # setdefault creates the platform's list on first registration.
    target.setdefault(platform, []).append(kernel_class)
__all__ = [
"init_fp8_linear_kernel",
"init_int8_linear_kernel",
"choose_mp_linear_kernel",
"register_linear_kernel",
"FP8ScaledMMLinearKernel",
"Int8ScaledMMLinearKernel",
"ScaledMMLinearKernel",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment