Commit 4851c202 authored by zhuwenwen
Browse files

Merge tag 'v0.6.1' into v0.6.1-dev

parents 9b902f9e 3fd2b0d2
import pytest
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
load_chat_template)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
......@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)
# Call the function and get the result
result = apply_chat_template(
result = apply_hf_chat_template(
tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
......
......@@ -2,7 +2,7 @@ from typing import Optional
import torch
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
class MyMod(torch.nn.Module):
......@@ -13,7 +13,7 @@ class MyMod(torch.nn.Module):
return x * 2
class MyWrapper(TorchCompileWrapperWithCustomDispacther):
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model):
self.model = model
......
......@@ -21,6 +21,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
......@@ -44,6 +45,7 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
def _read_prompts(filename: str) -> List[str]:
......@@ -85,8 +87,35 @@ class _ImageAssets(_ImageAssetsBase):
return [prompts["stop_sign"], prompts["cherry_blossom"]]
class _VideoAssetPrompts(TypedDict):
    """Prompt strings keyed by video asset name."""

    # Prompt paired with the "sample_demo_1" video asset.
    sample_demo_1: str
if sys.version_info >= (3, 9):

    # From Python 3.9 on, the base can carry its element type.
    class _VideoAssetsBase(UserList[VideoAsset]):
        pass
else:

    # UserList cannot be subscripted
    class _VideoAssetsBase(UserList):
        pass
class _VideoAssets(_VideoAssetsBase):
    """List-like collection holding the video assets used by the tests."""

    def __init__(self) -> None:
        bundled_videos = [VideoAsset("sample_demo_1.mp4")]
        super().__init__(bundled_videos)

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return one prompt per asset, in the order the assets are stored.

        Args:
            prompts: Mapping from asset name to the prompt text for it.

        Returns:
            The prompt strings as a list.
        """
        asset_keys = ("sample_demo_1", )
        return [prompts[key] for key in asset_keys]
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
@pytest.fixture(autouse=True)
......@@ -202,6 +231,11 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-scoped fixture returning the module-level ``VIDEO_ASSETS``."""
    shared_assets = VIDEO_ASSETS
    return shared_assets
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
......@@ -278,7 +312,8 @@ class HfRunner:
def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
......@@ -292,6 +327,8 @@ class HfRunner:
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
......@@ -314,7 +351,7 @@ class HfRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
......@@ -351,7 +388,8 @@ class HfRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = []
......@@ -362,6 +400,8 @@ class HfRunner:
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
......@@ -433,8 +473,9 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
......@@ -454,6 +495,8 @@ class HfRunner:
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
if videos is not None:
processor_kwargs["videos"] = videos[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
......@@ -634,12 +677,16 @@ class VllmRunner:
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
if images is not None:
assert len(prompts) == len(images)
if videos is not None:
assert len(prompts) == len(videos)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
......@@ -649,6 +696,11 @@ class VllmRunner:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}
if videos is not None:
for i, video in enumerate(videos):
inputs[i]["multi_modal_data"] = {"video": video}
print(f"[INPUTS!!!!]: {inputs}, {sampling_params}")
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
......@@ -671,7 +723,7 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
......@@ -685,6 +737,7 @@ class VllmRunner:
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
......@@ -694,7 +747,8 @@ class VllmRunner:
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios)
audios=audios,
videos=videos)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
......
......@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
if model.startswith("llava-hf/llava-1.5"):
from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
from ..models.test_llava_next import run_test # type: ignore[no-redef]
from ..models.test_llava_next import models
elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test
from ..models.test_chameleon import run_test # type: ignore[no-redef]
from ..models.test_chameleon import models
else:
raise NotImplementedError(f"Unsupported model: {model}")
......
......@@ -18,23 +18,28 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
"MODEL_NAME, DIST_BACKEND"),
[
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
])
@pytest.mark.parametrize(
("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
"MODEL_NAME, DIST_BACKEND"),
[
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"),
],
)
@fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND):
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
......@@ -43,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"8192",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
......@@ -59,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
tp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"float16",
"--max-model-len",
"8192",
"--tensor-parallel-size",
str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI.
"--distributed-executor-backend",
......@@ -71,6 +80,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if EAGER_MODE:
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")
if TRUST_REMOTE_CODE:
pp_args.append("--trust-remote-code")
tp_args.append("--trust-remote-code")
pp_env = None
if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
and CHUNKED_PREFILL):
......
......@@ -83,7 +83,7 @@ def test_local_workers() -> None:
workers[3].process.kill()
# Other workers should get shut down here
worker_monitor.join(2)
worker_monitor.join(20)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
......@@ -108,7 +108,7 @@ def test_local_workers_clean_shutdown() -> None:
# Clean shutdown
worker_monitor.close()
worker_monitor.join(5)
worker_monitor.join(20)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
......@@ -161,7 +161,7 @@ async def test_local_workers_async() -> None:
workers[3].process.kill()
# Other workers should get shut down here
worker_monitor.join(2)
worker_monitor.join(20)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
......
......@@ -50,7 +50,7 @@ def zephyr_lora_files():
@pytest.mark.skip_global_cleanup
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
lora_request = [
LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files)
for idx in range(len(PROMPTS))
]
# Multiple SamplingParams should be matched with each prompt
......
......@@ -8,7 +8,9 @@ from vllm.entrypoints.openai.protocol import BatchRequestOutput
INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {"stream": "True", "model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
......
from http import HTTPStatus
from unittest.mock import MagicMock
import pytest
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
MODEL_NAME = "meta-llama/Llama-2-7b"
LORA_LOADING_SUCCESS_MESSAGE = (
"Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
"Success: LoRA adapter '{lora_name}' removed successfully.")
async def _async_serving_engine_init():
    """Create an ``OpenAIServing`` instance wired to mocked dependencies.

    Returns:
        An ``OpenAIServing`` built from a mocked engine client and a mocked
        model config, serving ``MODEL_NAME`` with no LoRA modules, prompt
        adapters or request logger.
    """
    engine_client = MagicMock(spec=AsyncEngineClient)
    model_config = MagicMock(spec=ModelConfig)
    # Set the max_model_len attribute to avoid missing attribute
    model_config.max_model_len = 2048

    return OpenAIServing(engine_client,
                         model_config,
                         served_model_names=[MODEL_NAME],
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=None)
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Loading a new adapter returns the success message and registers it."""
    engine = await _async_serving_engine_init()
    load_request = LoadLoraAdapterRequest(lora_name="adapter",
                                          lora_path="/path/to/adapter2")

    result = await engine.load_lora_adapter(load_request)

    expected_message = LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter')
    assert result == expected_message
    assert len(engine.lora_requests) == 1
    assert engine.lora_requests[0].lora_name == "adapter"
@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """A load request with empty name and path yields a bad-request error."""
    engine = await _async_serving_engine_init()
    empty_request = LoadLoraAdapterRequest(lora_name="", lora_path="")

    result = await engine.load_lora_adapter(empty_request)

    assert isinstance(result, ErrorResponse)
    assert result.type == "InvalidUserInput"
    assert result.code == HTTPStatus.BAD_REQUEST
@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Loading the same adapter twice fails the second time.

    The first load must succeed; the repeated load must return a
    bad-request error and leave exactly one registered adapter.
    """
    engine = await _async_serving_engine_init()

    # First load: expected to succeed and register the adapter.
    first_request = LoadLoraAdapterRequest(lora_name="adapter1",
                                           lora_path="/path/to/adapter1")
    first_result = await engine.load_lora_adapter(first_request)
    assert first_result == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(engine.lora_requests) == 1

    # Second load with identical name/path: expected to be rejected.
    repeat_request = LoadLoraAdapterRequest(lora_name="adapter1",
                                            lora_path="/path/to/adapter1")
    repeat_result = await engine.load_lora_adapter(repeat_request)
    assert isinstance(repeat_result, ErrorResponse)
    assert repeat_result.type == "InvalidUserInput"
    assert repeat_result.code == HTTPStatus.BAD_REQUEST
    assert len(engine.lora_requests) == 1
@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Unloading a previously loaded adapter removes it and reports success."""
    engine = await _async_serving_engine_init()

    # Register an adapter so there is something to unload.
    load_request = LoadLoraAdapterRequest(lora_name="adapter1",
                                          lora_path="/path/to/adapter1")
    await engine.load_lora_adapter(load_request)
    assert len(engine.lora_requests) == 1

    unload_request = UnloadLoraAdapterRequest(lora_name="adapter1")
    unload_result = await engine.unload_lora_adapter(unload_request)

    assert unload_result == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(engine.lora_requests) == 0
@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """An unload request with an empty name and no id yields a bad request."""
    engine = await _async_serving_engine_init()
    empty_request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)

    result = await engine.unload_lora_adapter(empty_request)

    assert isinstance(result, ErrorResponse)
    assert result.type == "InvalidUserInput"
    assert result.code == HTTPStatus.BAD_REQUEST
@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Unloading an adapter that was never loaded yields a bad request."""
    engine = await _async_serving_engine_init()
    unknown_request = UnloadLoraAdapterRequest(
        lora_name="nonexistent_adapter")

    result = await engine.unload_lora_adapter(unknown_request)

    assert isinstance(result, ErrorResponse)
    assert result.type == "InvalidUserInput"
    assert result.code == HTTPStatus.BAD_REQUEST
......@@ -3,8 +3,10 @@ from typing import Type
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, SiluAndMul)
NewGELU, QuickGELU,
SiluAndMul)
from .allclose_default import get_default_atol, get_default_rtol
......@@ -39,18 +41,28 @@ def test_act_and_mul(
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu":
layer = SiluAndMul()
fn = torch.ops._C.silu_and_mul
elif activation == "gelu":
layer = GeluAndMul(approximate="none")
fn = torch.ops._C.gelu_and_mul
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
fn = torch.ops._C.gelu_tanh_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
(NewGELU, torch.ops._C.gelu_new),
(QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
......@@ -70,10 +82,14 @@ def test_activation(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation()
layer = activation[0]()
fn = activation[1]
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
out = torch.empty_like(x)
opcheck(fn, (out, x))
......@@ -6,6 +6,7 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
......@@ -199,6 +200,13 @@ def test_paged_attention(
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
......@@ -231,6 +239,14 @@ def test_paged_attention(
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query, key_cache,
value_cache, num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
else:
raise AssertionError(f"Unknown version: {version}")
......
......@@ -4,6 +4,7 @@ from typing import List, Tuple
import pytest
import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
......@@ -88,6 +89,11 @@ def test_copy_blocks(
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
opcheck(torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]))
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
......@@ -163,6 +169,10 @@ def test_reshape_and_cache(
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
......@@ -270,6 +280,10 @@ def test_reshape_and_cache_flash(
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
......@@ -367,6 +381,14 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
do_opcheck = (head_size == HEAD_SIZES[0])
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
cond=do_opcheck)
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
cond=do_opcheck)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
......
......@@ -7,6 +7,7 @@ from typing import Optional, Type
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
......@@ -108,6 +109,9 @@ def cutlass_int8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
......@@ -341,6 +345,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
if azp_per_token:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
func_bias))
else:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
func_bias))
# Test working with a subset of A and B
def test_cutlass_subset():
......
......@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size,
block_size,
"NONE",
data_type=dtype)
data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
......
......@@ -2,6 +2,7 @@ import pytest
import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
DTYPES = [torch.half, torch.bfloat16, torch.float]
......@@ -12,6 +13,16 @@ SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
def opcheck_int8_quant(output, input, scale=None):
    """Run ``opcheck`` on the int8 quantization op matching the arguments.

    Args:
        output: Tensor passed to the op as its output argument.
        input: Tensor passed to the op as its input argument.
        scale: Scale tensor. When given, the static op is checked; when
            ``None``, the dynamic op is checked with a freshly allocated
            float32 scale buffer.
    """
    if scale is None:
        # One scale slot per row of the input, where a row is the last
        # dimension (numel // last-dim size).
        num_rows = input.numel() // input.shape[-1]
        dynamic_scale = torch.empty((num_rows, 1),
                                    device=input.device,
                                    dtype=torch.float32)
        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
                (output, input, dynamic_scale))
        return

    opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
......@@ -34,6 +45,8 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out, ref_out, atol=1,
rtol=0.0) # big atol to account for rounding errors
opcheck_int8_quant(ops_out, x)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
......@@ -58,3 +71,5 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(
out1, out2, atol=1,
rtol=0.0) # big atol to account for rounding errors
opcheck_int8_quant(out2, x, scale)
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.half, torch.bfloat16, torch.float]
......@@ -52,3 +53,10 @@ def test_rms_norm(
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
if residual is not None:
opcheck(torch.ops._C.fused_add_rms_norm,
(x, residual, layer.weight.data, layer.variance_epsilon))
else:
opcheck(torch.ops._C.rms_norm,
(out, x, layer.weight.data, layer.variance_epsilon))
......@@ -9,6 +9,7 @@ from typing import Optional, Tuple
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
......@@ -76,6 +77,8 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
return w_ref, w_q_machete, w_s, w_zp
......@@ -146,6 +149,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule,
)
opcheck(torch.ops._C.machete_gemm,
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
......
......@@ -5,6 +5,7 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import pytest
import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
......@@ -73,12 +74,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
# Filter act_order
if act_order:
if group_size == -1:
......@@ -112,6 +110,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
opcheck(torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq,
......@@ -137,12 +138,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
# Normalize group_size
if group_size == -1:
group_size = size_k
......@@ -165,6 +163,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
opcheck(torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits))
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
......@@ -204,9 +205,6 @@ def test_gptq_marlin_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
if act_order:
if group_size == -1:
return
......@@ -224,6 +222,13 @@ def test_gptq_marlin_gemm(
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_gemm(
a_input,
marlin_q_w,
......@@ -245,7 +250,6 @@ def test_gptq_marlin_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
......@@ -265,9 +269,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
......@@ -279,6 +280,12 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
output_ref = torch.matmul(a_input, w_24_ref)
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm(
a_input,
marlin_24_q_w_comp,
......@@ -294,7 +301,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
......@@ -321,9 +327,6 @@ def test_fp8_marlin_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
......@@ -353,6 +356,10 @@ def test_fp8_marlin_gemm(
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
opcheck(torch.ops._C.fp8_marlin_gemm,
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
......@@ -368,7 +375,6 @@ def test_fp8_marlin_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
......@@ -396,9 +402,6 @@ def test_awq_marlin_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
......@@ -434,7 +437,6 @@ def test_awq_marlin_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
......@@ -460,9 +462,6 @@ def test_marlin_qqq_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
......@@ -479,6 +478,11 @@ def test_marlin_qqq_gemm(
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_MAX_PARALLEL)
opcheck(torch.ops._C.marlin_qqq_gemm,
(q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]))
output = ops.marlin_qqq_gemm(
q_a,
marlin_qqq_q_w,
......@@ -495,6 +499,5 @@ def test_marlin_qqq_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
......@@ -2,6 +2,8 @@
Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List
import pytest
import torch
from transformers import MixtralConfig
......@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types
def torch_moe(a, w1, w2, score, topk):
......@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe_single(a, w, score, topk):
    """Pure-PyTorch reference for one top-k routed expert matmul.

    Each row of ``a`` is duplicated ``topk`` times; every copy is multiplied
    by the transposed weight matrix of one of that row's ``topk``
    highest-scoring experts, and the ``topk`` products are summed per row.
    The softmax scores are only used to pick experts; they do not scale
    the products.

    Args:
        a: 2-D activations, one row per token.
        w: Stacked expert weights, indexed as ``w[expert]``; ``w.shape[1]``
            is the number of output features.
        score: Router logits, one row per row of ``a``.
        topk: Number of experts selected for every row.

    Returns:
        A tensor with one row per row of ``a`` and ``w.shape[1]`` columns.
    """
    num_tokens, hidden = a.shape
    num_experts, out_features = w.shape[0], w.shape[1]

    # Lay the rows out as [t0, t0, ..., t1, t1, ...] so that slot j of a row
    # lines up with that row's j-th selected expert below.
    tiled = a.view(num_tokens, -1, hidden)
    tiled = tiled.repeat(1, topk, 1)
    tiled = tiled.reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=tiled.dtype,
                         device=tiled.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    expert_ids = torch.topk(probs, topk)[1].view(-1)

    for expert in range(num_experts):
        routed = expert_ids == expert
        if not routed.sum():
            # No slot was routed to this expert.
            continue
        result[routed] = tiled[routed] @ w[expert].transpose(0, 1)

    # Fold the topk slots of every row back into a single row.
    return result.view(num_tokens, -1, out_features).sum(dim=1)
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
......@@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
......@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading dimension.

    The stacked result is moved to the device of the first tensor in the
    list, so the list must not be empty.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name this is not a maximum; the value is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.
    """
    mean_abs_err = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_err / mean_abs_ref
def _quantize_experts_for_marlin(weights, quant_type, group_size, act_order):
    """Marlin-quantize every expert in ``weights`` and stack the results.

    For each expert ``weights[i]`` a fresh random permutation of length
    ``weights.shape[2]`` is drawn and passed, together with the transposed
    expert weight, to ``marlin_quantize``.

    Returns:
        ``(w_ref, qweight, scales, g_idx, sort_indices)``, each stacked over
        experts along dim 0; ``qweight`` is additionally made contiguous.
    """
    perm_size = weights.shape[2]
    w_ref_l = []
    qweight_l = []
    scales_l = []
    g_idx_l = []
    sort_indices_l = []
    for expert_weight in weights:
        # Draw the permutation before quantizing so the order of RNG calls
        # is the same for every expert.
        test_perm = torch.randperm(perm_size)
        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
            expert_weight.transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref_l.append(w_ref)
        qweight_l.append(qweight)
        scales_l.append(scales)
        g_idx_l.append(g_idx)
        sort_indices_l.append(sort_indices)

    return (
        stack_and_dev(w_ref_l),
        stack_and_dev(qweight_l).contiguous(),
        stack_and_dev(scales_l),
        stack_and_dev(g_idx_l),
        stack_and_dev(sort_indices_l),
    )


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
):
    """Check ``fused_marlin_moe`` against ``fused_moe`` on reference weights.

    Both expert weight sets are quantized with ``marlin_quantize``; the
    Marlin kernel runs on the quantized weights while ``fused_moe`` runs on
    the ``w_ref`` tensors returned by the quantizer (presumably the
    dequantized weights -- confirm against ``marlin_quantize``). The relative
    mean error between the two outputs must stay below 4e-2.
    """
    torch.manual_seed(7)

    # More experts requested per token than exist: nothing to test.
    if topk > e:
        return

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return

    quant_type = scalar_types.uint4b8
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Only expert 0 of w2 is overwritten with a (k, n) identity-like matrix;
    # a single assignment is enough since the target index never changes.
    # NOTE(review): confirm whether every expert was meant to be overwritten.
    w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

    # w1 is quantized before w2 so the random permutations are drawn in a
    # fixed order.
    w_ref1, qweight1, scales1, g_idx1, sort_indices1 = (
        _quantize_experts_for_marlin(w1, quant_type, group_size, act_order))
    w_ref2, qweight2, scales2, g_idx2, sort_indices2 = (
        _quantize_experts_for_marlin(w2, quant_type, group_size, act_order))

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        score,
        g_idx1,
        g_idx2,
        sort_indices1,
        sort_indices2,
        topk_weights,
        topk_ids,
        w1_scale=scales1,
        w2_scale=scales2,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
):
    """Debug-only check of ``single_marlin_moe`` against ``torch_moe_single``.

    One set of expert weights is quantized with ``marlin_quantize``; the
    Marlin kernel output on the quantized weights is compared with the
    pure-PyTorch reference evaluated on the ``w_ref`` tensors, and the
    relative mean error must stay below 1e-2.
    """
    # Cannot pick more experts per token than there are experts.
    if topk > e:
        return

    # Filter act_order
    if act_order and group_size in (-1, k):
        return

    quant_type = scalar_types.uint4b8
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # One (w_ref, qweight, scales, g_idx, sort_indices) tuple per expert.
    per_expert = []
    for expert_w in w:
        perm = torch.randperm(k)
        ref_i, q_i, scale_i, gidx_i, sort_i, _ = marlin_quantize(
            expert_w.transpose(1, 0), quant_type, group_size, act_order, perm)
        per_expert.append((ref_i, q_i, scale_i, gidx_i, sort_i))

    # Regroup from per-expert tuples into per-field lists.
    refs, qweights, all_scales, all_g_idx, all_sort = map(
        list, zip(*per_expert))

    w_ref = stack_and_dev(refs)
    qweight = stack_and_dev(qweights).contiguous()
    scales = stack_and_dev(all_scales)
    g_idx = stack_and_dev(all_g_idx)
    sort_indices = stack_and_dev(all_sort)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = single_marlin_moe(a,
                                      qweight,
                                      scales,
                                      score,
                                      g_idx,
                                      sort_indices,
                                      topk,
                                      renormalize=False)
    reference = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, reference) < 1e-2
......@@ -3,7 +3,8 @@
import itertools
import random
from numbers import Number
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union)
import pytest
import torch
......@@ -13,6 +14,21 @@ from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
# "test_aot_dispatch_dynamic" is left out of the default set for now because
# there are some bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# The default checks followed by "test_aot_dispatch_dynamic".
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    *DEFAULT_OPCHECK_TEST_UTILS,
    "test_aot_dispatch_dynamic",
)
class QKVInputs(NamedTuple):
'''
......@@ -926,3 +942,19 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
ideal_output = test_params.packed_qkvo.ideal_output
torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output))
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    """Conditionally run ``torch.library.opcheck`` on a custom op.

    Args:
        op: The operator to check.
        args: Positional arguments forwarded to the check.
        kwargs: Optional keyword arguments forwarded to the check.
        test_utils: Name(s) of the opcheck tests to run.
        raise_exception: Forwarded unchanged to ``torch.library.opcheck``.
        cond: When false, nothing is checked.

    Returns:
        The result of ``torch.library.opcheck``, or an empty dict when
        ``cond`` is false.
    """
    if not cond:
        return {}

    return torch.library.opcheck(op,
                                 args,
                                 kwargs,
                                 test_utils=test_utils,
                                 raise_exception=raise_exception)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment