Unverified Commit 300da091 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Fullgraph and opcheck tests (#8479)

parent 1c046447
"""
Opcheck tests for the rotary embedding custom ops
"""
from typing import Optional
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot,
                             positions: torch.Tensor,
                             query: torch.Tensor,
                             key: torch.Tensor,
                             offsets: Optional[torch.Tensor] = None):
    """Run ``opcheck`` against the rotary-embedding custom op for the inputs.

    ``torch.ops._C.batched_rotary_embedding`` is checked when ``offsets`` is
    given; otherwise ``torch.ops._C.rotary_embedding`` is checked.

    Args:
        rot: Object providing ``cos_sin_cache``, ``head_size``,
            ``is_neox_style`` and ``rotary_dim``.
        positions: Position tensor, forwarded to the op unchanged.
        query: Query tensor; its device and dtype are used to cast the
            cos/sin cache.
        key: Key tensor, forwarded to the op unchanged.
        offsets: Optional offsets tensor; its presence selects the batched op.
    """
    # ops.rotary_embedding()/batched_rotary_embedding() are in-place
    # operations that update the query and key tensors.
    cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)

    # Leading arguments are identical for both ops.
    shared_args = (positions, query, key, rot.head_size, cache,
                   rot.is_neox_style)

    if offsets is None:
        opcheck(torch.ops._C.rotary_embedding, shared_args)
        return

    opcheck(torch.ops._C.batched_rotary_embedding,
            shared_args + (rot.rotary_dim, offsets))
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
                                  is_neox_style, rotary_dim, head_size,
                                  seq_len):
    """Opcheck both the plain and the batched rotary-embedding ops.

    The same random positions/query/key tensors are used for both checks;
    the batched variant additionally receives an all-zero offsets tensor.
    """
    n_batch, n_heads = 1, 7

    # Positional arguments: head_size, rotary_dim, max_position, base (0),
    # is_neox_style, dtype.
    rope = RotaryEmbedding(head_size, rotary_dim, max_position, 0,
                           is_neox_style, torch.float32)

    # Keep the random draws in this order (positions, query, key) so the
    # generated data matches for a given seed.
    positions = torch.randint(0,
                              max_position, (n_batch, seq_len),
                              device=device)
    query = torch.randn(n_batch,
                        seq_len,
                        n_heads * head_size,
                        dtype=torch.float32,
                        device=device)
    key = torch.randn_like(query)

    # One zero offset per (batch, position) entry for the batched op.
    zero_offsets = torch.zeros(n_batch * seq_len,
                               device=device,
                               dtype=torch.long)

    # None exercises rotary_embedding, a tensor exercises the batched op.
    for offsets in (None, zero_offsets):
        rotary_embedding_opcheck(rope, positions, query, key, offsets)
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck():
    """Opcheck ``torch.ops._C_cache_ops.convert_fp8`` on a 256x256 input."""
    src = torch.randn((256, 256), dtype=torch.float32, device="cuda")
    # Destination has the same shape/device as the source, but fp8 storage.
    dst = torch.empty_like(src, dtype=torch.float8_e4m3fn)
    op_args = (dst, src, 1.0, "fp8")
    opcheck(torch.ops._C_cache_ops.convert_fp8, op_args)
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
    """Opcheck the two device-attribute query ops in ``_C_cuda_utils``."""
    cuda_utils = torch.ops._C_cuda_utils

    # NOTE(review): the integer arguments are presumably attribute/device
    # ids - confirm against the op schemas.
    opcheck(cuda_utils.get_device_attribute, (0, 0))
    opcheck(cuda_utils.get_max_shared_memory_per_block_device_attribute,
            (0, ))
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
import itertools import itertools
import random import random
import unittest
from numbers import Number from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union) Union)
import pytest import pytest
import torch import torch
from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
...@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, ...@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test.view_as(ideal_output)) output_under_test.view_as(ideal_output))
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Variant of torch.allclose that compares the inputs in float64.

    Arguments are validated with ``torch._refs._check_close_args`` and then
    both tensors are upcast with ``.double()`` before ``torch.isclose`` is
    applied, so the comparison itself never runs in the original dtype
    (intended for fp8 inputs, per the function name).

    Args:
        a: First tensor.
        b: Second tensor.
        rtol: Relative tolerance forwarded to ``torch.isclose``.
        atol: Absolute tolerance forwarded to ``torch.isclose``.
        equal_nan: Whether NaNs in matching positions count as equal.

    Returns:
        True when every element pair is close, False otherwise.
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    # Compare in double precision rather than in the inputs' own dtype.
    elementwise_close = torch.isclose(a.double(),
                                      b.double(),
                                      rtol=rtol,
                                      atol=atol,
                                      equal_nan=equal_nan)
    return bool(elementwise_close.all().item())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef], torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...], args: Tuple[Any, ...],
...@@ -954,6 +984,7 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, ...@@ -954,6 +984,7 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True, raise_exception: bool = True,
cond: bool = True) -> Dict[str, str]: cond: bool = True) -> Dict[str, str]:
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
return torch.library.opcheck( return torch.library.opcheck(
op, op,
args, args,
......
...@@ -20,8 +20,10 @@ if not current_platform.is_tpu(): ...@@ -20,8 +20,10 @@ if not current_platform.is_tpu():
if current_platform.is_rocm(): if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401 import vllm._rocm_C # noqa: F401
supports_moe_ops = False
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
supports_moe_ops = True
def hint_on_error(fn): def hint_on_error(fn):
...@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this if hasattr(torch.ops._C, "gptq_gemm"):
try:
torch.ops._C.gptq_gemm # noqa B018
@torch.library.register_fake("_C::gptq_gemm") @torch.library.register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
...@@ -265,8 +265,6 @@ try: ...@@ -265,8 +265,6 @@ try:
return torch.empty((a.size(0), b_q_weight.size(1)), return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype, dtype=a.dtype,
device=a.device) device=a.device)
except Exception:
pass
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
...@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k) size_n, size_k)
# TODO: has to be a better way to do this if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
try:
torch.ops._C.gptq_marlin_24_gemm # noqa B018
@torch.library.register_fake("_C::gptq_marlin_24_gemm") @torch.library.register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
...@@ -420,8 +416,8 @@ try: ...@@ -420,8 +416,8 @@ try:
@torch.library.register_fake("_C::machete_gemm") @torch.library.register_fake("_C::machete_gemm")
def machete_gemm_fake( def machete_gemm_fake(
a: torch.Tensor, a: torch.Tensor,
b_q: torch. # Should be the tensor returned by machete_prepack_B
Tensor, # Should be the tensor returned by machete_prepack_B b_q: torch.Tensor,
b_type: ScalarType, b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None, b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None, b_zeros: Optional[torch.Tensor] = None,
...@@ -451,10 +447,10 @@ try: ...@@ -451,10 +447,10 @@ try:
return torch.empty_like(x) return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update") @torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, def causal_conv1d_update_fake(
weight: torch.Tensor, x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], bias_: Optional[torch.Tensor], silu_activation: bool,
silu_activation: bool) -> torch.Tensor: conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x) return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd") @torch.library.register_fake("_C::selective_scan_fwd")
...@@ -465,20 +461,11 @@ try: ...@@ -465,20 +461,11 @@ try:
delta_softplus: bool, index_: Optional[torch.Tensor], delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]: x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u) a = torch.empty_like(u)
if x is not None:
b = x
else:
b = torch.empty((u.size(0), u.size(1), A.size(1)),
dtype=u.dtype,
device=u.device)
if z_ is not None: if z_ is not None:
c = torch.empty_like(z_) c = torch.empty_like(z_)
return [a, b, c] return [a, c]
else: else:
return [a, b] return [a]
except Exception:
pass
# cutlass # cutlass
...@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor, ...@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return torch.ops._C.machete_prepack_B(b_q_weight, b_type) return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# TODO: has to be a better way to do this if hasattr(torch.ops._C, "permute_cols"):
try:
torch.ops._C.permute_cols # noqa B018
@torch.library.register_fake("_C::permute_cols") @torch.library.register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor, def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor: perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a) return torch.empty_like(a)
except Exception:
pass
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
...@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, ...@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output) token_expert_indicies, gating_output)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @torch.library.register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: int,
        size_n: int,
        size_k: int,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        """Fake implementation registered for ``_moe_C::marlin_gemm_moe``.

        Only output metadata is produced: an uninitialized tensor of shape
        ``(size_m, topk, size_n)`` with the dtype and device of ``a``. All
        other arguments are accepted to match the op signature and ignored.
        """
        out_shape = (size_m, topk, size_n)
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
......
...@@ -361,7 +361,7 @@ def selective_scan_fn(u, ...@@ -361,7 +361,7 @@ def selective_scan_fn(u,
x[:, :, 0, 0::2] = 1 x[:, :, 0, 0::2] = 1
if prev_state is not None: if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state) x[:, :, 0, 1::2].copy_(prev_state)
out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x) delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if z is None: if z is None:
......
...@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded # exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass # here we do the shuffle on first forward pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment