Unverified Commit 300da091 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Fullgraph and opcheck tests (#8479)

parent 1c046447
"""
Tests for miscellaneous utilities
"""
from typing import Optional
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
opcheck(torch.ops._C.batched_rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style, rot.rotary_dim, offsets))
else:
opcheck(torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style))
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len):
batch_size = 1
base = 0
num_heads = 7
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=device)
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device=device)
key = torch.randn_like(query)
rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,
device=device,
dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets)
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck():
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
opcheck(
torch.ops._C_cuda_utils.
get_max_shared_memory_per_block_device_attribute, (0, ))
......@@ -2,12 +2,14 @@
import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union)
import pytest
import torch
from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
......@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test.view_as(ideal_output))
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
a: TensorLikeType,
b: TensorLikeType,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> bool:
"""
Reference implementation of torch.allclose
"""
torch._refs._check_close_args(name="torch.allclose",
a=a,
b=b,
rtol=rtol,
atol=atol)
return bool(
torch.all(
torch.isclose(a.double(),
b.double(),
rtol=rtol,
atol=atol,
equal_nan=equal_nan)).item())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...],
......@@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True) -> Dict[str, str]:
return torch.library.opcheck(
op,
args,
kwargs,
test_utils=test_utils,
raise_exception=raise_exception) if cond else {}
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
return torch.library.opcheck(
op,
args,
kwargs,
test_utils=test_utils,
raise_exception=raise_exception) if cond else {}
......@@ -20,8 +20,10 @@ if not current_platform.is_tpu():
if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401
supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
def hint_on_error(fn):
......@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_gemm # noqa B018
if hasattr(torch.ops._C, "gptq_gemm"):
@torch.library.register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
......@@ -265,8 +265,6 @@ try:
return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype,
device=a.device)
except Exception:
pass
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
......@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_marlin_24_gemm # noqa B018
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
......@@ -420,8 +416,8 @@ try:
@torch.library.register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
b_q: torch.
Tensor, # Should be the tensor returned by machete_prepack_B
# Should be the tensor returned by machete_prepack_B
b_q: torch.Tensor,
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
......@@ -451,10 +447,10 @@ try:
return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
......@@ -465,20 +461,11 @@ try:
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u)
if x is not None:
b = x
else:
b = torch.empty((u.size(0), u.size(1), A.size(1)),
dtype=u.dtype,
device=u.device)
if z_ is not None:
c = torch.empty_like(z_)
return [a, b, c]
return [a, c]
else:
return [a, b]
except Exception:
pass
return [a]
# cutlass
......@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# TODO: has to be a better way to do this
try:
torch.ops._C.permute_cols # noqa B018
if hasattr(torch.ops._C, "permute_cols"):
@torch.library.register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
except Exception:
pass
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
......@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@torch.library.register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int,
is_k_full: bool, num_experts: int, topk: int,
moe_block_size: int, replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,
device=a.device)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
......
......@@ -361,8 +361,8 @@ def selective_scan_fn(u,
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if z is None:
return out if not return_last_state else (out, last_state)
......
......@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment