"cmake/vscode:/vscode.git/clone" did not exist on "c7914d30f90bc47f1c959d3330666885a0034f7d"
Commit e7c1b7f3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.4-dtk24.04.1'

parents 7462218e 04c62b93
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock
import pytest
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
pytestmark = pytest.mark.openai
@dataclass
class MockModelConfig:
......@@ -36,11 +37,47 @@ async def _async_serving_chat_init():
model_config,
served_model_names=[MODEL_NAME],
response_role="assistant",
chat_template=CHAT_TEMPLATE)
chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion
def test_async_serving_chat_init():
serving_completion = asyncio.run(_async_serving_chat_init())
assert serving_completion.tokenizer is not None
assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
assert serving_completion.chat_template == CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLMEngine)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
served_model_names=[MODEL_NAME],
response_role="assistant",
chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
# AsyncLLMEngine.generate(inputs, sampling_params, ...)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
req.max_tokens = 10
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].max_tokens == 10
import openai # use the official client for correctness check
import pytest
import requests
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def server(zephyr_lora_added_tokens_files: str): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
"--max-num-seqs",
"128",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
zephyr_lora_added_tokens_files: str): # noqa: F811
return zephyr_lora_added_tokens_files if (
model_name == "zephyr-lora2") else model_name
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_completions(client: openai.AsyncOpenAI,
model_name: str, tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
for add_special in [False, True]:
prompt = "vllm1 This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
response = requests.post(base_url + "/tokenize",
json={
"add_special_tokens": add_special,
"model": model_name,
"prompt": prompt
})
response.raise_for_status()
assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
for add_generation in [False, True]:
for add_special in [False, True]:
conversation = [{
"role": "user",
"content": "Hi there!"
}, {
"role": "assistant",
"content": "Nice to meet you!"
}, {
"role": "user",
"content": "Can I ask a question? vllm1"
}]
prompt = tokenizer.apply_chat_template(
add_generation_prompt=add_generation,
conversation=conversation,
tokenize=False)
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
response = requests.post(base_url + "/tokenize",
json={
"add_generation_prompt":
add_generation,
"add_special_tokens": add_special,
"messages": conversation,
"model": model_name
})
response.raise_for_status()
assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
prompt = "This is a test prompt. vllm1"
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(f"CALLING {base_url} FOR {model_name}")
response = requests.post(base_url + "/detokenize",
json={
"model": model_name,
"tokens": tokens
})
response.raise_for_status()
assert response.json() == {"prompt": prompt}
from pathlib import Path
from typing import Dict
from typing import Dict, List
import openai
import pytest
import pytest_asyncio
import ray
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
from vllm.multimodal.utils import encode_image_base64, fetch_image
from ..utils import VLLM_PATH, RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
"examples/template_llava.jinja")
LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja"
assert LLAVA_CHAT_TEMPLATE.exists()
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
......@@ -22,37 +19,21 @@ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
pytestmark = pytest.mark.openai
@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
@pytest.fixture(scope="module")
def server():
return RemoteOpenAIServer([
"--model",
MODEL_NAME,
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--enforce-eager",
"--image-input-type",
"pixel_values",
"--image-token-id",
"32000",
"--image-input-shape",
"1,3,336,336",
"--image-feature-size",
"576",
"--chat-template",
str(LLAVA_CHAT_TEMPLATE),
])
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
......@@ -60,11 +41,10 @@ def client(server):
return server.get_async_client()
@pytest_asyncio.fixture(scope="session")
async def base64_encoded_image() -> Dict[str, str]:
@pytest.fixture(scope="session")
def base64_encoded_image() -> Dict[str, str]:
return {
image_url:
encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
image_url: encode_image_base64(fetch_image(image_url))
for image_url in TEST_IMAGE_URLS
}
......@@ -216,7 +196,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
temperature=0.0,
stream=True,
)
chunks = []
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
......@@ -279,7 +259,3 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
if __name__ == "__main__":
pytest.main([__file__])
from typing import Optional, Tuple, Union
import torch
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \
-> Tuple[torch.tensor, torch.tensor]:
assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.
# Compute scales
x_token_max, _ = x.abs().max(dim=-1)
x_token_max = as_float32_tensor(x_token_max)
if scale_ub is not None:
x_token_max = x_token_max.clamp(max=scale_ub)
scales = (x_token_max / qtype_max)[:, None]
# Quant
if quant_dtype == torch.int8:
iscales = as_float32_tensor(s_1 / scales)
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round()
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
else:
assert quant_dtype == torch.float8_e4m3fn
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)
torch_out = as_float32_tensor(x) / scales
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
return torch_out, scales
# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(torch.float8_e4m3fn)
fp8_max = as_float32_tensor(fp8_traits.max)
one = as_float32_tensor(1.0)
# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.
x_max = as_float32_tensor(x.abs().max())
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
return ref_out, ref_scale
......@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
] if not is_hip() else [64, 80, 96, 112, 128]
BLOCK_SIZES = [16, 32]
......@@ -73,27 +73,27 @@ def ref_single_query_cached_kv_attention(
block_size = value_cache.shape[3]
num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
block_tables_lst = block_tables.cpu().tolist()
seq_lens_lst = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
seq_len = int(seq_lens[i])
block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i])
keys = []
values = []
keys_lst: List[torch.Tensor] = []
values_lst: List[torch.Tensor] = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size)
keys.append(k)
keys_lst.append(k)
v = value_cache[block_number, :, :, block_offset]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
......@@ -135,6 +135,8 @@ def test_paged_attention(
seed: int,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
......@@ -158,14 +160,15 @@ def test_paged_attention(
# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
block_tables_lst: List[List[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int)
block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
......@@ -175,7 +178,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
kv_scale = 1.0
k_scale = v_scale = 1.0
# Call the paged attention kernel.
output = torch.empty_like(query)
......@@ -193,7 +196,8 @@ def test_paged_attention(
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
)
elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
......@@ -224,7 +228,8 @@ def test_paged_attention(
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
)
else:
raise AssertionError(f"Unknown version: {version}")
......@@ -284,7 +289,7 @@ def ref_multi_query_kv_attention(
dtype: torch.dtype,
) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1
ref_outputs = []
ref_outputs: List[torch.Tensor] = []
for i in range(num_seqs):
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
......@@ -304,8 +309,8 @@ def ref_multi_query_kv_attention(
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output
return torch.cat(ref_outputs, dim=0)
# TODO(woosuk): Add tests for USE_ALIBI=True.
......
......@@ -9,8 +9,8 @@ from vllm.attention.selector import which_attn_to_use
@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
......@@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
......@@ -42,36 +47,36 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"
assert backend.name != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
\ No newline at end of file
......@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention(
block_size = value_cache.shape[3]
num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
block_tables_lst = block_tables.cpu().tolist()
seq_lens_lst = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
seq_len = int(seq_lens[i])
block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i])
keys = []
values = []
keys_lst: List[torch.Tensor] = []
values_lst: List[torch.Tensor] = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size)
keys.append(k)
keys_lst.append(k)
v = value_cache[block_number, :, :, block_offset]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
......@@ -212,7 +212,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
kv_scale = 1.0
k_scale = v_scale = 1.0
tp_rank = 0
# Call the paged attention kernel.
......@@ -231,7 +231,8 @@ def test_paged_attention(
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
......@@ -267,7 +268,8 @@ def test_paged_attention(
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
......@@ -432,7 +434,7 @@ def test_varlen_blocksparse_attention_prefill(
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
cu_seq_lens.tolist(),
query,
key,
value,
......
import random
from typing import Tuple
from typing import List, Tuple
import pytest
import torch
......@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]
# Arbitrary values for testing
......@@ -53,6 +53,8 @@ def test_copy_blocks(
kv_cache_dtype: str,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
......@@ -64,7 +66,7 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = []
block_mapping: List[Tuple[int, int]] = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
......@@ -125,6 +127,8 @@ def test_reshape_and_cache(
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
......@@ -132,8 +136,8 @@ def test_reshape_and_cache(
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
_, key, value = qkv.unbind(dim=1)
......@@ -156,11 +160,11 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone()
# Using default kv_scale
kv_scale = 1.0
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel.
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, kv_scale)
kv_cache_dtype, k_scale, v_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
......@@ -171,12 +175,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies = block_indicies.cpu().tolist()
block_indicies_lst = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
......@@ -216,8 +220,6 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8":
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
......@@ -225,8 +227,10 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
qkv = torch.randn(num_tokens,
3,
......@@ -247,29 +251,57 @@ def test_reshape_and_cache_flash(
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
key_cache, value_cache = key_caches[0].contiguous(
), value_caches[0].contiguous()
del key_caches
del value_caches
# Clone the KV caches.
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel.
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
slot_mapping, kv_cache_dtype, k_scale, v_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, value_cache)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = block_indicies.cpu().tolist()
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies_lst = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
......@@ -298,6 +330,8 @@ def test_swap_blocks(
) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
......
......@@ -2,36 +2,53 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Type
from typing import Optional, Type
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability = torch.cuda.get_device_capability()
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.tensor):
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.tensor):
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
output = output + bias
return output
def cutlass_fp8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
......@@ -42,16 +59,19 @@ def cutlass_fp8_gemm_helper(m: int,
m_a_scales = m if per_token_act_quant else 1
n_b_scales = n if per_out_channel_weight_quant else 1
scale_a = (torch.randn(
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
else:
bias = None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
def cutlass_int8_gemm_helper(m: int,
......@@ -59,6 +79,7 @@ def cutlass_int8_gemm_helper(m: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
......@@ -69,79 +90,106 @@ def cutlass_int8_gemm_helper(m: int,
m_a_scales = m if per_token_act_quant else 1
n_b_scales = n if per_out_channel_weight_quant else 1
scale_a = (torch.randn(
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=out_dtype)
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
else:
bias = None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
@pytest.mark.parametrize("m", [512, 222, 100, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch)
per_out_ch: bool, use_bias: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch)
per_out_ch: bool, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype]):
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
out_dtype)
out_dtype: Type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
use_bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype]):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
out_dtype)
out_dtype: Type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
use_bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
use_bias: bool, device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
torch.bfloat16, device)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
device: str):
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
torch.bfloat16, device)
use_bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
use_bias,
out_dtype=torch.bfloat16,
device=device)
# For the following two tests:
......@@ -151,20 +199,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
use_bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
use_bias)
# Test working with a subset of A and B
......@@ -185,9 +239,11 @@ def test_cutlass_subset():
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
......
"""
Tests:
* E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively
"encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching
"""
from typing import NamedTuple, Optional
import pytest
import torch
from tests.kernels.utils import *
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.utils import is_hip
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128]
# Narrow teest-cases for unsupported-scenario
# tests
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]
class TestPoint(NamedTuple):
"""
Encapsulates the attributes which define a single invocation
of the test_e2e_enc_dec_attn() test
Attributes:
num_heads: The number of heads in the model.
head_size: Head dimension
backend_name: Name of the backend framework used.
batch_size: Number of samples per batch.
block_size: Size of each block of data processed.
max_dec_seq_len: Maximum sequence length for the decoder.
max_enc_seq_len: Maximum sequence length for the encoder.
num_blocks: Number of blocks in the model.
"""
num_heads: int
head_size: int
backend_name: str
batch_size: int
block_size: int
max_dec_seq_len: int
max_enc_seq_len: int
num_blocks: int
class TestResources(NamedTuple):
'''
Encapsulates key components for performing an
encoder/decoder attention test
Note that
(1) attn automatically selects an attention backend
based on platform info & a set of canned
heuristics
(2) attn_backend is thus *not the same backend
instance* used by attn, but rather it is
intended to be a
*different instance* of the *same backend class*;
it is assumed that the user of TestResources
will leverage attn_backend for the purpose of
constructing backend-compatible attention
metadata instances
Attributes:
* scale: 1/sqrt(d) scale factor for attn
* attn_backend: implementatino of abstraction
attention interface using
a particular kernel library
i.e. XFormers
* attn: Attention layer instance
* kv_cache: shared key/value cache for all attention
'''
scale: float
attn_backend: AttentionBackend
attn: Attention
kv_cache: torch.Tensor
def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
'''
Build key components for performing encoder/decoder attention test.
Note that
(1) The Attention instance constructed here, automatically selects
an attention backend class based on platform info & a set of canned
heuristics, so
(2) The attention backend instance constructed here is thus *not
the same backend instance* used by attn, but rather it is
intended to be a *different instance* of the *same backend class*;
therefore,
(3) This function requires that test_pt.backend_name matches the backend
class that Attention will automatically select when it is constructed.
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: num_heads, head_size, num_blocks,
block_size, backend_name
Returns:
* TestResources data structure.
'''
scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention(
test_pt.num_heads,
test_pt.head_size,
scale=scale,
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(scale, attn_backend, attn, None)
# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE)
return TestResources(scale, attn_backend, attn, kv_cache)
def _encoder_attn_setup(
test_pt: TestPoint,
test_rsrcs: TestResources,
) -> PhaseTestParameters:
'''
Set up test vectors & data structures for encoder attention test.
A triplet of synthetic query/key/value tensors are constructed.
Given this is an encoder attention test, the key & value
sequences will have the same length as the corresponding queries.
The query/key/value tensors are passed to an ideal reference
self-attention implementation to generate an ideal output tensor.
Encoder inference does not populate the KV cache, therefore
no KV cache memory mapping is constructed
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
Returns:
* PhaseTestParameters data structure comprising (1) packed query/key/value
tensors, (2) the ideal output of attention computed using a naive
implementation, and (3) KVCache field set to None
'''
(
num_heads,
head_size,
_,
batch_size,
_,
_,
max_q_seq_len,
_,
) = test_pt
scale = test_rsrcs.scale
max_kv_seq_len = max_q_seq_len
# Make test tensors
qkv_in, _, _ = make_qkv(batch_size,
max_q_seq_len,
max_kv_seq_len,
num_heads,
head_size,
attn_type=AttentionType.ENCODER,
device=CUDA_DEVICE)
# Compute correct answer using naive non-causal attention
# implementation
ideal_output = ref_masked_attention(qkv_in.query,
qkv_in.key,
qkv_in.value,
scale=scale,
q_seq_lens=qkv_in.q_seq_lens,
kv_seq_lens=qkv_in.kv_seq_lens)
packed_ideal_output, _ = pack_tensor(ideal_output,
qkv_in.q_seq_lens,
device=CUDA_DEVICE)
packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)
return PhaseTestParameters(
PackedQKVO(packed_qkv, packed_ideal_output),
None # No KV cache
)
def _decoder_attn_setup(
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
'''
Set up test vectors & data structures for self-attention test.
A triplet of synthetic query/key/value tensors are constructed ("baseline"
query/key/value). Given this is a self-attention test, the key & value
sequences will have the same length as the corresponding queries.
"Prefill" query/key/value tensors are derived by masking out the last value
in each baseline query/key/value. These tensors are used to test prefill &
populate KV cache for a subsequent decode test.
"Decode" query/key/value tensors are derived by extracting *only* the last
value from each baseline query/key/value (i.e. complement of the prefill
tensors.) These tensors are used to test decode, conditional on the kv cache
being populated during the prefill test.
The baseline query/key/value tensors are passed to an ideal reference
self-attention implementation to generate a "Baseline" ideal output tensor.
This tensor is split into the "Prefill" ideal output tensor (all but the
last element of each output sequence) and the "Decode" ideal output tensor
(*only* the last element of each output sequence); the "Prefill" and
"Decode" ideal output tensors can be used to validate the prefill and decode
test results, respectively.
This function also constructs the self-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
* block_base_addr: decoder self-attention block-table base address
Returns:
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x
head_size) query/key/value tensors
* Prefill-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for prefill phase.
* Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for decode phase.
* max_block_idx: max physical address in decoder self-attention block-table
(intended to be used as the base address for the encoder/
decoder cross-attention block-table, which is not
constructed in this function)
'''
(
num_heads,
head_size,
_,
batch_size,
block_size,
max_q_seq_len,
_,
_,
) = test_pt
scale = test_rsrcs.scale
max_kv_seq_len = max_q_seq_len
# Build test tensors
(
qkv,
prefill_qkv,
decode_qkv,
) = make_qkv(batch_size,
max_q_seq_len,
max_kv_seq_len,
num_heads,
head_size,
attn_type=AttentionType.DECODER,
device=CUDA_DEVICE)
# Compute correct answer using naive attention implementation
# with causal attention mask
causal_mask = make_causal_mask(max_q_seq_len,
max_kv_seq_len).to(CUDA_DEVICE)
ideal_output = ref_masked_attention(qkv.query,
qkv.key,
qkv.value,
scale=scale,
custom_mask=causal_mask,
q_seq_lens=qkv.q_seq_lens,
kv_seq_lens=qkv.kv_seq_lens)
# Split out the prefill- & decode-phase ideal answers & pack them
prefill_ideal_output = torch.zeros_like(ideal_output)
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
bdx, :prefill_q_seq_len]
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
prefill_q_seq_len + 1)]
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
prefill_qkv.q_seq_lens,
device=CUDA_DEVICE)
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
[1 for _ in range(batch_size)],
device=CUDA_DEVICE)
# Build prefill- & decode-phase data structures
# for decoder self-attention. Block tables and
# slot mapping must be in a format compatible
# with KV caching & attention kernels
#
# Prefill-phase:
#
# * Empty block-tables tensor
# * Slot-mapping with entries for prompt tokens
#
# Decode-phase:
# * Block-tables tensor with minimum number of blocks
# required by total num. tokens in the entirety of all sequences
# (including both prefill & decode)
# * Slot-mapping with entries for tokens that will be decoded in the
# current decode iteration
#
# Note: the format described above is simply mirroring what ModelRunner
# produces
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
(
decode_block_tables,
slot_mapping_list,
max_block_idx,
) = make_block_tables_slot_mapping(block_size,
qkv.q_seq_lens,
device=CUDA_DEVICE,
block_base_addr=block_base_addr)
(
prefill_slot_mapping,
decode_slot_mapping,
) = split_slot_mapping(slot_mapping_list,
qkv.q_seq_lens,
device=CUDA_DEVICE)
prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)
return (
qkv,
PhaseTestParameters( # Prefill test params
PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
PhaseTestParameters( # Decode test params
PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
KVMemoryMap(decode_block_tables, decode_slot_mapping)),
max_block_idx)
def _enc_dec_cross_attn_setup_reuses_query(
decoder_qkv: QKVInputs,
encoder_test_params: PhaseTestParameters,
prefill_decoder_phase_test_params: PhaseTestParameters,
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
'''
Set up test vectors & data structures for cross-attention test.
A triplet of synthetic cross-attention key/value tensors are constructed
("baseline" key/value). Given this is a cross-attention test, we assume
query tensors were already synthesized for a prior self-attention test and
will be reused for cross-attention. The key & value sequences generated here
may have a different length than the corresponding queries (as is often
the case for cross-attention between decoder and encoder sequences.)
Cross attention key & value tensors do not grow during autoregressive
inference; thus this function obtains a single key/value pair suitable for
both prefill and decode.
The "baseline" query tensor is received as an argument. The "baseline"
query/key/value tensors are passed to an ideal reference cross-attention
implementation to generate a "baseline" ideal output tensor. This tensor is
split into the "Prefill" ideal output tensor (all but the last element of
each output sequence) and the "Decode" ideal output tensor (*only* the last
element of each output sequence); the "Prefill" and "Decode" ideal output
tensors can be used to validate the prefill and decode test results,
respectively.
This function also constructs the cross-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr.
Arguments:
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
num_heads x head_size) decoder self-attention inputs;
this function relies on the query and q_seq_lens
fields
* encoder_test_params: PhaseTestParameters data structure which was
used for encoder inference; KV cache field
is not used by this function
* prefill_decoder_phase_test_params: PhaseTestParameters data structure
used for prefill-phase decoder
self-attention; all fields
including KV cache required
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
* block_base_addr: decoder self-attention block-table base address
Returns:
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for decode phase.
'''
assert encoder_test_params.packed_qkvo.packed_qkv is not None
assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None
(
num_heads,
head_size,
_,
batch_size,
block_size,
max_decoder_seq_len,
max_encoder_seq_len,
_,
) = test_pt
scale = test_rsrcs.scale
decoder_query = decoder_qkv.query
decoder_seq_lens = decoder_qkv.q_seq_lens
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
prefill_q_seq_lens = (
prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)
assert prefill_q_seq_lens is not None
(
cross_kv,
_,
_,
) = make_qkv(batch_size,
max_decoder_seq_len,
max_encoder_seq_len,
num_heads,
head_size,
force_kv_seq_lens=encoder_seq_lens,
attn_type=AttentionType.ENCODER_DECODER,
device=CUDA_DEVICE)
ideal_output = ref_masked_attention(decoder_query,
cross_kv.key,
cross_kv.value,
scale=scale,
q_seq_lens=decoder_seq_lens,
kv_seq_lens=cross_kv.kv_seq_lens)
prefill_ideal_output = torch.zeros_like(ideal_output)
decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
bdx, :prefill_q_seq_len]
decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
prefill_q_seq_len + 1)]
prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
prefill_q_seq_lens,
device=CUDA_DEVICE)
decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
[1 for _ in range(batch_size)],
device=CUDA_DEVICE)
# Build prefill- & decode-phase data structures
# for encoder/decoder cross-attention. Block tables and
# slot mapping must be in a format compatible
# with KV caching & attention kernels
#
# Whereas decoder self-attention extracts relationships between
# equal-length Q/K/V sequences, which mutually grow in length
# with each decoded token, cross-attention relates the Q sequence
# - which grows with each new decoded token - to fixed-length
# K and V sequences derived from the encoder hidden states.
#
# Prefill-phase:
#
# * Empty block-tables tensor
# * Slot-mapping with as many entries as there are tokens in the encoder
# prompt.
#
# Decode-phase:
# * Block-tables tensor with minimum number of blocks to
# accommodate K & V tensors which are equal in lnegth
# to the encoder prompt length
# * Empty slot-mapping tensor (since K & V are fixed in size,
# new decoded tokens are not KV-cached and require no slot-
# mapping)
#
# Note: the format above is simply an extension of what ModelRunner
# produces for decoder-only models
prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)
(
decode_block_tables,
prefill_slot_mapping_list,
_,
) = make_block_tables_slot_mapping(block_size,
cross_kv.kv_seq_lens,
block_base_addr=block_base_addr,
device=CUDA_DEVICE)
prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
device=CUDA_DEVICE)
# Packed key/value (query is already provided)
packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)
return (
PhaseTestParameters( # Prefill-phase test params
PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
PhaseTestParameters( # Decode-phase test params
PackedQKVO(None, decode_packed_ideal_output),
KVMemoryMap(decode_block_tables, decode_slot_mapping)))
def _run_encoder_attention_test(
attn: Attention,
encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
Run encoder attention.
attn.forward() is passed attn_type=AttentionType.ENCODER in order
to configure the kernel invocation for encoder attention
Requires attn_metadata.num_decode_tokens == 0
(There is no encoder execution in the decode-phase)
Arguments:
* attn: Attention wrapper instance
* encoder_test_params: encoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
Returns:
* Attention.forward() applied to packed {query,key,value} and
& attn_metadata
'''
assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
None,
attn_metadata,
attn_type=attn_type)
def _run_decoder_self_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
Run decoder self-attention test.
attn.forward() is passed attn_type=AttentionType.DECODER
in order to configure the kernel invocation for decoder self-attention.
Arguments:
* test_rsrcs: TestResources instance; this function relies on the kv_cache
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
def _run_encoder_decoder_cross_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
'''
Run encoder/decoder cross-attention test.
Via PhaseTestParameters data structures, consumes the same query utilized
for decoder self-attention, plus a key/value specific to cross-attention.
if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv
is None, this reflects that in decode-phase cross attention there
is no growth in the key and value tensors.
attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
in order to configure the kernel invocation for encoder/decoder cross-
attention.
Arguments:
* test_rsrcs: TestResources instance; this function relies on the kv_cache
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
query field
* cross_test_params: encoder/decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None:
key = None
value = None
else:
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
batch_size: int, block_size: int, max_dec_seq_len: int,
max_enc_seq_len: int, monkeypatch):
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
# Shared prefill metadata structure
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
True,
None,
decoder_test_params=None,
encoder_test_params=enc_test_params,
cross_test_params=None,
device=CUDA_DEVICE)
# PREFILL: encoder attention
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
num_heads: int,
head_size: int,
backend_name: str,
batch_size: int,
block_size: int,
max_dec_seq_len: int,
max_enc_seq_len: int,
monkeypatch,
) -> None:
'''
End-to-end encoder/decoder test:
* Construct fake test vectors for (1) encoder attention,
(2) decoder self-attention, and (3) encoder/decoder cross-attention
* Construct (1) attention metadata structure with self- and cross-attention
attributes for prefill-phase, and (2) an analogous attention metadata
structure but for decode-phase
* Test attention steps in the following order
* Encoder attention
* Prefill self-attention
* Prefill cross-attention
* Decode self-attention
* Decode cross-attention
* Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention
block tables, which one hopes to avoid
* Validate output correctness against ideal reference attention
implementation
Block tables are constructed such that cross-attention KV cache is in a
higher, non-intersecting address-space than self-attention KV cache.
Self- and cross-attention share the same query tensor but not the K/V
tensors. Self-attention K/Vs must have the same seq len as Q while
cross-attention K/Vs are allowed to differ in seq len, as is often the case
for cross-attention.
This test utilizes PyTest monkey patching to force the attention backend
via an environment variable.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior
of ModelRunner, which constructs a single attention metadata structure for
each prefill or decode run. A realistic scenario would rely on the
attention backend to utilize the appropriate attention metadata fields
according to the value of attn_metadata.attention_type. Thus, this test is
organized so as to confirm that the backend-under-test can handle a
shared prefill attention metadata structure & a shared decode attention
metadata structure.
'''
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
# memory-mapping. cross_block_base_addr is the uppermost address in the
# decoder self-attention block-table, i.e. a base address which the
# encoder/decoder cross-attention block-table may build downward toward.
(
dec_qkv,
prephase_dec_test_params,
decphase_dec_test_params,
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
# test params, including key/value tensors, cross-attention memory-mapping
(
prephase_cross_test_params,
decphase_cross_test_params,
) = _enc_dec_cross_attn_setup_reuses_query(
dec_qkv,
enc_test_params,
prephase_dec_test_params,
test_pt,
test_rsrcs,
block_base_addr=cross_block_base_addr)
# Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params,
encoder_test_params=enc_test_params,
cross_test_params=prephase_cross_test_params,
device=CUDA_DEVICE)
# PREFILL: encoder attention
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata)
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params,
prephase_dec_pckd_act_out)
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
prephase_attn_metadata)
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params,
prephase_cross_pckd_act_out)
# DECODE: build decode-phase attention metadata
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
False,
dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params,
encoder_test_params=enc_test_params,
cross_test_params=decphase_cross_test_params,
device=CUDA_DEVICE)
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params,
decphase_dec_pckd_act_out)
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params,
decphase_cross_pckd_act_out)
......@@ -20,12 +20,13 @@ def ref_paged_attn(
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs = []
outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
......@@ -53,6 +54,8 @@ def ref_paged_attn(
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
......@@ -68,13 +71,15 @@ def ref_paged_attn(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[Tuple[int, int]],
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
......@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)
ref_output = ref_paged_attn(
......@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
......@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@torch.inference_mode()
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
......@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
......@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
......@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
ref_output = ref_paged_attn(
......@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
from typing import List, Optional, Tuple
import flashinfer
import pytest
import torch
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]
if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
outputs.append(out)
start_idx += query_len
return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
num_heads: Tuple[int,
int], head_size: int,
dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)
output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int, dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
qo_indptr = [0]
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
qo_indptr.append(qo_indptr[-1] + query_lens[i])
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
)
output = wrapper.forward(
query,
key_value_cache,
logits_soft_cap=soft_cap,
)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SCALE_UBS = [True, False]
SEEDS = [0]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans
scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
if scale_ub else None
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)
assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x)
assert torch.allclose(ref_scale, ops_scale)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
# Regression test for a case with large activations where an int32 index cannot
# represent the number of elements.
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings
hidden_size = 1152 # Smallest hidden_size to reproduce the error
dtype = torch.bfloat16
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, _ = ops.scaled_fp8_quant(x, scale)
# Minimize memory footprint in this test by freeing x and upconverting
# the outputs in place. (torch.allclose does not support fp8)
del x
ref_out = ref_out.to(dtype=dtype)
ops_out = ops_out.to(dtype=dtype)
assert torch.allclose(ref_out, ops_out)
import pytest
import torch
# ruff: noqa: F401
import vllm._C
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
......@@ -21,23 +21,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
x_token_max, _ = x.max(dim=1)
x_token_max = x_token_max.to(dtype=torch.float32)
scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
dtype=torch.float32)
torch_out = (x / scales).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
# reference
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
# kernel
ops_out, ops_scales = scaled_int8_quant(x)
assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out,
assert torch.allclose(ops_scales, ref_scales)
assert torch.allclose(ops_out, ref_out,
atol=1) # big atol to account for rounding errors
......@@ -55,12 +48,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale = torch.tensor([scale], dtype=torch.float32, device="cuda")
out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
out2, _ = scaled_int8_quant(x, scale)
torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors
......@@ -5,23 +5,33 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
marlin_quantize, marlin_weights)
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_permute_scales, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 128, 256]
......@@ -38,21 +48,29 @@ MNK_FACTORS = [
(67, 13, 11),
]
DTYPES = [torch.float16, torch.bfloat16]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
def rand_data(shape):
return torch.randn(shape, dtype=torch.half, device="cuda")
def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
mnk_factors):
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
......@@ -77,11 +95,11 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
b_weight = rand_data((size_k, size_n))
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits,
group_size, act_order)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order)
# Pack to GPTQ format
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
......@@ -90,8 +108,9 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits,
marlin_perm[num_bits])
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
......@@ -99,30 +118,85 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
sort_indices,
size_k,
size_n,
num_bits,
quant_type.size_bits,
)
torch.cuda.synchronize()
assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Create input
b_weight = rand_data((size_k, size_n))
# Quantize
w_ref, q_w, s, zp = quantize_weights(b_weight,
quant_type,
group_size,
zero_points=True)
# Pack to AWQ format
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
size_k,
size_n,
quant_type.size_bits,
)
torch.cuda.synchronize()
assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
def test_marlin_gemm(
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
......@@ -143,7 +217,9 @@ def test_marlin_gemm(
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, num_bits, group_size, act_order)
b_weight, quant_type, group_size, act_order)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
......@@ -152,14 +228,17 @@ def test_marlin_gemm(
a_input,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
num_bits,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
)
output_ref = torch.matmul(a_input, w_ref)
......@@ -171,14 +250,15 @@ def test_marlin_gemm(
assert max_diff < 0.04
@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
......@@ -192,7 +272,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
b_weight = rand_data((size_k, size_n))
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
......@@ -205,7 +285,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
marlin_24_meta,
marlin_24_s,
workspace_24.scratch,
num_bits,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
......@@ -217,3 +297,204 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
# WEIGHTS
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device="cuda"),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=size_k,
size_n=size_n,
group_size=-1)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
b_scales=marlin_scales,
workspace=workspace.scratch,
num_bits=num_bits,
size_m=a_input.shape[0],
size_n=b_weight.shape[1],
size_k=a_input.shape[1],
)
output_ref = torch.matmul(a_input, b_weight)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True
has_zp = True
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.gptq_marlin_gemm(
a_input,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
)
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
):
int8_traits = torch.iinfo(torch.int8)
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
# Quantize activations
s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
torch.float)
q_a = (a_input / s_a).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
# Quantize weights
w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
marlin_qqq_quantize(b_weight, num_bits, group_size)
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_MAX_PARALLEL)
output = ops.marlin_qqq_gemm(
q_a,
marlin_qqq_q_w,
s_a,
marlin_qqq_s_channel,
marlin_qqq_s_group,
workspace.scratch,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
)
output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
\ No newline at end of file
......@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
......@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
......
from itertools import accumulate, product
from typing import List, Optional
from typing import Dict, List, Optional
import pytest
import torch
......@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing
......@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
query,
key,
offsets=torch.zeros(batch_size * seq_len,
dtype=int,
dtype=torch.long,
device=device))
# Compare the results.
assert torch.allclose(out_query,
......@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = [
None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
}
]
settings = [
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES
]
rope_setting_id_map = {}
ROPE_SCALINGS = (None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
......
import gc
from unittest.mock import patch
import pytest
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.sample import (
MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits,
sample)
from vllm.model_executor.layers.ops.sample import (_sample_triton,
_uniform_to_exponential,
sample)
from vllm.model_executor.sampling_metadata import SamplingTensors
from vllm.model_executor.utils import set_random_seed
from vllm.triton_utils.libentry import LibEntry
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
get_num_triton_sampler_splits)
SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100
......@@ -75,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, bs),
device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
_save_modified_probs=True)
#The current _sample_triton does not utilize the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch("vllm.model_executor.layers.ops.sample._sample_triton",
LibEntry(_sample_triton)):
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
_save_modified_probs=True)
assert sampled_tokens.shape == (bs, max_best_of)
for i in range(bs):
assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
......@@ -129,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size):
set_random_seed(seed)
prompt_sizes = [16, 32, 64, 128] * 2
samples = 8
......@@ -156,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, samples),
device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, _ = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=True)
#ditto
with patch("vllm.model_executor.layers.ops.sample._sample_triton",
LibEntry(_sample_triton)):
sampled_tokens, sampled_logprobs, _ = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=True)
assert sampled_tokens.shape == (samples, max_best_of)
assert sampled_logprobs.shape == (samples, max_best_of)
for i, t in enumerate(sample_indices):
......
"""Kernel test utils"""
import itertools
import random
from numbers import Number
from typing import Any, List, NamedTuple, Optional, Tuple, Union
import pytest
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import make_tensor_with_pad
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
class QKVInputs(NamedTuple):
'''
Data structure for representing unpacked attention inputs,
query/key/values and their sequence lengths.
Attributes:
* {query,key,value}: unpacked (batch_size x padded_seq_len x
num_heads x head_size) attention inputs
* q_seq_lens: query sequence lengths list
* kv_seq_lens: shared key/value sequence lengths list
'''
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_seq_lens: List[int]
kv_seq_lens: List[int]
class QKVO(NamedTuple):
'''
Data structure for representing unpacked attention inputs,
alongside unpacked known-correct attention output
Attributes:
* qkv: unpacked (batch_size x padded_seq_len x
num_heads x head_size) attention inputs
* ideal_output: unpacked (batch_size x padded_seq_len x
num_heads x head_size) known-correct attention output
'''
qkv: QKVInputs
ideal_output: torch.Tensor
class PackedQKVInputs(NamedTuple):
'''
Data structure for representing packed attention inputs
Attributes:
* {query,key,value}: packed (number_of_tokens x num_heads
x head_size) attention inputs
* q_start_loc_list: list of query start locations within packed tensor
* kv_start_loc_list: shared list of key/value start locations within
packed tensor
* q_seq_lens: query sequence lengths list
* kv_seq_lens: shared key/value sequence lengths list
'''
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_start_loc_list: Optional[List[int]]
kv_start_loc_list: Optional[List[int]]
q_seq_lens: Optional[List[int]]
kv_seq_lens: Optional[List[int]]
class PackedQKVO(NamedTuple):
'''
Data structure for representing packed attention inputs,
alongside packed known-correct attention output
Attributes:
* packed_qkv: packed (number_of_tokens x num_heads
x head_size) attention inputs
* ideal_output: packed (number_of_tokens x num_heads
x head_size) known-correct attention output
'''
packed_qkv: Optional[PackedQKVInputs]
ideal_output: torch.Tensor
class KVMemoryMap(NamedTuple):
'''
Data structure for encapsulating KV cache memory mapping.
Attributes:
* block_tables: KV cache block tables
* slot_mapping: mapping of sequence offset to physical address
'''
block_tables: torch.Tensor
slot_mapping: torch.Tensor
class PhaseTestParameters(NamedTuple):
'''
Data structure for encapsulating the test parameters
for a given test "phase" (prefill or decode phase) and attention
scenario (encoder, decoder-self, encoder/decoder-cross)
Attributes:
* packed_qkvo: packed (number_of_tokens x num_heads
x head_size) attention inputs & known-correct
output
* kv_mmap: KV cache memory mapping, specific to this test phase &
attention scenario
'''
packed_qkvo: PackedQKVO
kv_mmap: Optional[KVMemoryMap]
def maybe_make_int_tensor(
_list: Optional[List[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
Convert Python int list to a 1D int torch.Tensor on `device`
Returns:
* If _list is not None: 1D int torch.Tensor on `device`
* None otherwise
'''
return None if _list is None else torch.tensor(
_list, dtype=torch.int, device=device)
def maybe_make_long_tensor(
_list: Optional[List[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
Convert Python int list to a 1D long torch.Tensor on `device`
Returns:
* If _list is not None: 1D long torch.Tensor on `device`
* None otherwise
'''
return None if _list is None else torch.tensor(
_list, dtype=torch.long, device=device)
def maybe_max(_list: Optional[List]) -> Optional[Number]:
'''
Returns:
* If _list is not None: max(_list)
* None otherwise
'''
return None if _list is None else max(_list)
def make_causal_mask(
q_max_seq_len: int,
kv_max_seq_len: int,
) -> torch.Tensor:
'''
Create a q_max_seq_len x kv_max_seq_len causal mask
Arguments:
* q_max_seq_len: query max seq len
* kv_max_seq_len: key/value max seq len
Returns:
* 2D tensor, q_max_seq_len x kv_max_seq_len
'''
# Create a matrix where entry (i, j) is True if i >= j
mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
# Replace True with float('-inf') and False with 0
mask = mask.masked_fill(mask == 1,
float('-inf')).masked_fill(mask == 0, 0.0)
return mask
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
backend_name: str) -> None:
'''
......@@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
* backend_name: attention backend name to force
'''
mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
def ref_masked_attention(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[List] = None,
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
'''
"Golden" masked attention reference. Supports two types of masking:
* Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
padding elements
* Custom attention mask, which can force an arbitrary mask tensor, i.e.
causal
Arguments:
* query: batch_size x q_padded_seq_len x num_heads x head_size
* key: batch_size x kv_padded_seq_len x num_heads x head_size
* value: batch_size x kv_padded_seq_len x num_heads x head_size
* scale: Attention scale factor
* custom_mask: custom attention mask; good place to inject a causal
attention mask
* q_seq_lens: list of unpadded query seq_lens for each batch index
* kv_seq_lens: list of unpadded key/value seq_lens for each batch index
Returns:
* Attention result, batch_size x q_padded_seq_len x num_heads x head_size
'''
assert q_seq_lens is not None
assert kv_seq_lens is not None
batch_size = query.shape[0]
assert (len(q_seq_lens) == batch_size)
assert (len(kv_seq_lens) == batch_size)
attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()
# Basic attention mask, derived from seq lens
if (q_seq_lens is not None) or (kv_seq_lens is not None):
attn_mask = torch.zeros_like(attn_weights)
if q_seq_lens is not None:
for bdx, plen in enumerate(q_seq_lens):
attn_mask[bdx, :, plen:, :] = -torch.inf
if kv_seq_lens is not None:
for bdx, plen in enumerate(kv_seq_lens):
attn_mask[bdx, :, :, plen:] = -torch.inf
attn_weights = attn_weights + attn_mask.float()
# Custom attention mask
if custom_mask is not None:
attn_weights = attn_weights + custom_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
return out
def make_qkv(
batch_size: int,
max_q_seq_len: int,
max_kv_seq_len: Optional[int],
num_heads: int,
head_size: int,
device: Union[torch.device, str],
force_kv_seq_lens: Optional[List[int]] = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
'''
Construct QKV test tensors for self- and cross-attention.
Generates three query/key/value triplets:
* "Baseline" query/key/value (for input to reference attention function)
* "Prefill" query/key/value (last sequence offset zero'd out, for use as
input to prefill kernel)
* "Decode" query/key/value (only the last sequence offset from baseline,
for use as input to decode kernel)
Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
seqlens
Arguments:
* batch_size
* max_q_seq_len: max query seq len
* max_kv_seq_len: max key/value seq len
* num_heads
* head_size
* is_encoder_decoder_attn: if True, query seqlen may differ from
key/value seqlen (as is often the case for cross-attention);
o/w, query/key/value seqlens match at each batch index
(max_kv_seq_len is unused)
* force_kv_seq_lens: if not None, overrides kv sequence lengths
* attn_type: encoder, decoder self, or enc/dec cross attention
* force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens
and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False
* device: CPU or CUDA device
Returns:
* Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
* Prefill QKVInputs structure (containing all but the last sequence offset)
* Decode QKVInputs structure (containing all only the last sequence offset)
'''
if force_max_len:
q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
else:
q_seq_lens = [
random.randint(2, max_q_seq_len) for _ in range(batch_size)
]
kv_seq_lens = None
if force_kv_seq_lens is not None:
kv_seq_lens = force_kv_seq_lens
elif attn_type != AttentionType.ENCODER_DECODER:
# K,V seq lens match Q for self-attention
kv_seq_lens = q_seq_lens
else:
# K,V seq lens are distinct from Q seq lens & random
assert max_kv_seq_len is not None
if force_max_len:
kv_seq_lens = [max_kv_seq_len] * batch_size
else:
kv_seq_lens = [
random.randint(2, max_kv_seq_len) for _ in range(batch_size)
]
query = torch.rand(
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
key = torch.rand(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
value = torch.rand(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
prefill_query = torch.zeros(
(batch_size, max_q_seq_len, num_heads, head_size)).to(device)
prefill_key = torch.zeros(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
prefill_value = torch.zeros(
(batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
decode_query = torch.zeros(
(batch_size, 1, num_heads, head_size)).to(device)
decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
decode_value = torch.zeros(
(batch_size, 1, num_heads, head_size)).to(device)
for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
kv_seq_lens)):
query[bdx, q_seq_len:, :, :] = 0
key[bdx, kv_seq_len:, :, :] = 0
value[bdx, kv_seq_len:, :, :] = 0
prefill_query[bdx,
0:(q_seq_len - 1), :, :] = query[bdx,
0:(q_seq_len - 1), :, :]
prefill_key[bdx,
0:(kv_seq_len - 1), :, :] = key[bdx,
0:(kv_seq_len - 1), :, :]
prefill_value[bdx, 0:(kv_seq_len -
1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]
decode_query[bdx, :, :, :] = query[bdx,
(q_seq_len - 1):q_seq_len, :, :]
decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
decode_value[bdx, :, :, :] = value[bdx,
(kv_seq_len - 1):kv_seq_len, :, :]
prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]
decode_q_seq_lens = [1 for _ in q_seq_lens]
decode_kv_seq_lens = [1 for _ in kv_seq_lens]
return (
QKVInputs(
query, # Overall QKV inputs
key,
value,
q_seq_lens,
kv_seq_lens),
QKVInputs(
prefill_query, # Prefill subset of QKV sequences
prefill_key,
prefill_value,
prefill_q_seq_lens,
prefill_kv_seq_lens),
QKVInputs(
decode_query, # Decode subset of KV sequences
decode_key,
decode_value,
decode_q_seq_lens,
decode_kv_seq_lens))
def pack_tensor(
unpacked_tensor: torch.Tensor, seq_lens: List[int],
device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
'''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where
number_of_tokens = sum(seq_lens)
Arguments:
* unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
* seq_lens: list of token counts for each seq
* device: CPU or CUDA device
Returns
* packed_tensor: number_of_tokens x num_heads x head_size
* start_loc_list: start idx of each batch elt in packed_tensor; [0] +
list(itertools.accumulate(seq_lens))
'''
num_tok = sum(seq_lens)
num_heads = unpacked_tensor.shape[-2]
head_size = unpacked_tensor.shape[-1]
start_loc_list = [0] + list(itertools.accumulate(seq_lens))
packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)
for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
packed_tensor[start_loc:(
start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]
return packed_tensor, start_loc_list
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
str]) -> PackedQKVInputs:
'''
Individually pack each of Q, K and V, each with dimensions batch_size x
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
num_heads x head_size tensors.
For Q, number_of_tokens = sum(q_seq_lens).
For K and V, number_of_tokens = sum(kv_seq_lens)
Arguments:
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
attention inputs
* device: CPU or CUDA device
Returns
* Packed (number_of_tokens x num_heads x head_size) QKV inputs
derived from unpacked inputs
'''
if qkv.query is None:
packed_query = None
q_start_loc_list = None
else:
packed_query, q_start_loc_list = pack_tensor(qkv.query,
qkv.q_seq_lens,
device=device)
packed_key, kv_start_loc_list = pack_tensor(qkv.key,
qkv.kv_seq_lens,
device=device)
packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
return PackedQKVInputs(
packed_query, packed_key, packed_value, q_start_loc_list,
kv_start_loc_list,
(None if q_start_loc_list is None else qkv.q_seq_lens),
qkv.kv_seq_lens)
def make_backend(backend_name: str) -> AttentionBackend:
'''
Construct the backend instance determined by the backend_name string
argument.
"XFORMERS" -> construct xformers backend
TODO: other backends
Note: at time of writing the Attention wrapper automatically selects
its own backend for Attention.forward(); so the backend instance which
you generate with this function is not meant to be used for *running*
inference, but rather for generating compatible metadata structures
using backend.make_metadata()
Returns:
* Backend instance
'''
if backend_name == STR_XFORMERS_ATTN_VAL:
return XFormersBackend()
raise AssertionError(
f"Unrecognized backend_name {backend_name} for unit test")
def _make_metadata_tensors(
seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
torch.Tensor, Optional[int]]:
'''
Build scalar & tensor values required to build attention metadata structure.
Arguments:
* seq_lens: list of token-counts for each decoder input seq
* context_lens: list of context length values for each seq
* encoder_seq_lens: list of token-counts for each encoder input seq
* device: CPU or CUDA device
Returns:
* seq_lens_tensor: decoder seq_lens list, as tensor
* context_lens_tensor: context_lens list, as tensor
* max_context_len: max(context_lens)
* max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence
* max_encoder_seq_len: encoder seq_lens list, as tensor
'''
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
context_lens_tensor = maybe_make_int_tensor(context_lens, device)
max_context_len = maybe_max(context_lens)
max_seq_len = maybe_max(seq_lens)
encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
max_encoder_seq_len = (None if encoder_seq_lens is None else
max(encoder_seq_lens))
seq_start_loc = None
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
def make_kv_cache(num_blocks: int,
num_heads: int,
head_size: int,
block_size: int,
device: Union[torch.device, str],
default_val: float = 0.0) -> torch.Tensor:
'''
Create a fake KV cache.
Arguments:
* num_blocks: number of blocks in the KV cache
* num_heads: number of attention heads
* head_size: head dimension
* block_size: number of offsets within a block
* device: CPU or CUDA device
* default_val: initialization value for KV cache elements
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
'''
kv_cache = torch.rand(
(2, num_blocks, block_size * num_heads * head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
'''
Compute the minimum number of blocks required to hold num_tokens tokens,
given block_size
'''
return (num_tokens + block_size) // block_size
def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
return maybe_make_long_tensor([], device)
def make_empty_block_tables_tensor(device: Union[torch.device, str]):
return torch.tensor([], device=device)
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
device: Union[torch.device, str]):
'''
Split a slot mapping into valid prefill- and decode-phase slot mappings.
Context:
* Your goal is to test (1) prefill of N prompts, with prompt-lengths
{K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
for all N prompts (N tokens total); the resultant sequence lengths
after decode would be {K_i + 1 for i \\in [0,N)}
* The test you want to do requires (1) having the prefill slot mapping
for all tokens present during prefill, the number of which is
M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
decoded tokens
This function consumes a single 1D slot mapping, which is the
concatenation of N slot mappings each of length K_i + 1 (corresponding
to the sequence lengths after decode), with a total length of
P = \\sum_i{K_i + 1} = M + N
The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
from each of the N subsequences in the slot mapping (i.e. omitting the
decoded token's mapping.)
The N excised entries are appended to obtain the decode-phase slot mapping
Arguments:
* slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
post-decode sequences
* seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
description above)
* device: cuda, cpu, etc.
Returns:
* prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
reflecting all N prefill prompts
* decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
all N decoded tokens
'''
prefill_slot_mapping = []
decode_slot_mapping = []
base_idx = 0
for seq_len in seq_lens:
prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
seq_len - 1)])
decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
base_idx += seq_len
return (maybe_make_long_tensor(prefill_slot_mapping, device),
maybe_make_long_tensor(decode_slot_mapping, device))
def make_block_tables_slot_mapping(
block_size: int,
seq_lens: List[int],
device: Union[torch.device, str],
block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
'''
Construct fake block tables & slot mappings.
For a sequence with num_tokens tokens the minimum number
of required KV cache blocks is
num_blocks = (num_tokens + block_size) // block_size
Then the minimum KV cache size in blocks is
total_cache_blocks = sum(num_blocks for all seqs)
Then, the blocktable mapping counts downward from
block_base_addr + total_cache_blocks
to
block_base_addr
The constructed block-tables and slot-mapping are sized to the
lengths of the sequences in their entirety (as reflected by seq_lens),
i.e. the total of prefill prompt tokens + decoded tokens.
Arguments:
* block_size: number of offsets per block
* seq_lens: list of token-counts for each sequence
* block_base_addr: the block table base address
* device: CPU or CUDA device
Return:
* block_tables_tensor: block table for sequence
* slot_mapping_list: slot mapping for sequence
* max_block_idx: the highest block address within this block table
'''
# Provision minimum number of KV cache blocks
num_blocks_list = [
_num_tokens_to_min_blocks(num_tokens, block_size)
for num_tokens in seq_lens
]
max_block_table_len = max(num_blocks_list)
block_table_pad_tokens = 10
block_tables = []
slot_mapping_list = []
# Compute uppermost address of block table
total_cache_blocks = sum(num_blocks_list)
block_base_idx = block_base_addr + total_cache_blocks
max_block_idx = block_base_idx
for sdx, num_tokens in enumerate(seq_lens):
num_blocks = num_blocks_list[sdx]
block_table = list(
range(block_base_idx, block_base_idx - num_blocks, -1))
for idx in range(num_tokens):
mapping_value = (
idx % block_size) + block_table[idx // block_size] * block_size
slot_mapping_list.append(mapping_value)
block_base_idx -= num_blocks
block_tables.append(block_table)
block_tables_tensor = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len + block_table_pad_tokens,
pad=0,
dtype=torch.int,
device=device,
)
return (block_tables_tensor, slot_mapping_list, max_block_idx)
def make_test_metadata(
attn_backend: AttentionBackend,
is_prompt: bool,
seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None,
cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
'''
Construct fake attention metadata for a given test phase
(prefill-phase or decode-phase).
encoder_test_params and cross_test_params arguments allow encoder
attention and enc/dec cross-attention (respectively) to use distinct
metadata values from decoder self-attention (decoder_test_params.)
if encoder_test_params and cross_test_params are None, the attention
metadata will support decoder-only scenario.
Assumptions:
* No chunked prefill -> a batch is 100% prefill or 100% decode, never both
Arguments:
* attn_backend: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
this function requires
kv_mmap (memory mapping) field
* device: CPU or CUDA device
* encoder_test_params: encoder attention test params;
this function requires encoder query
sequence lengths field. If None,
encoder query sequence lengths are
treated as None
* cross_test_params: enc/dec cross-attention test params;
this function requires kv_mmap field.
If None, KV cache memory map data
structures are treated as None
Return:
* AttentionMetadata structure
'''
# Decoder self-attention memory mapping
# decoder_test_params is None signals encoder-only
# scenario, so kv_mmap is None
kv_mmap = (None
if decoder_test_params is None else decoder_test_params.kv_mmap)
# This function constructs metadata assuming no chunked prefill,
# i.e. 100% prefill tokens or 100% decode tokens
#
# - If is_prompt, num_prefills_or_decodes is the number of prefills
# and num_prefill_or_decode_tokens is the number of prefill tokens
# - If not is_prompt, num_prefills_or_decodes is the number of decodes
# and num_prefill_or_decode_tokens is the number of decode tokens
#
# seq_lens is None signals encoder-only
# scenario, in which case num_prefills_or_decodes and
# num_prefill_or_decode_tokens are unused
num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))
num_prefill_or_decode_tokens = (None if seq_lens is None else (
sum(seq_lens) if is_prompt else len(seq_lens)))
# Seems for non-prefix-caching scenarios context_lens
# is never needed
context_lens = None
if encoder_test_params is None:
encoder_seq_lens = None
num_encoder_tokens = None
else:
# Encoder/decoder or encoder-only models only:
# * Extract encoder input sequence lengths
assert encoder_test_params.packed_qkvo.packed_qkv is not None
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
num_encoder_tokens = (None if encoder_seq_lens is None else
(sum(encoder_seq_lens)))
if cross_test_params is None:
cross_kv_mmap = None
else:
# Encoder/decoder or encoder-only models only:
# * Extract *cross-attention* slot_mapping and block table
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
if is_prompt:
# Prefill-phase scenario
num_prefills = num_prefills_or_decodes
num_prefill_tokens = num_prefill_or_decode_tokens
num_decode_tokens = 0
(
seq_lens_tensor,
context_lens_tensor,
_,
_,
_,
encoder_seq_lens_tensor,
max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens,
context_lens,
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
max_decode_seq_len=0,
context_lens_tensor=context_lens_tensor,
block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
use_cuda_graph=False,
num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping),
cross_block_tables=(None if cross_kv_mmap is None else
cross_kv_mmap.block_tables))
else: # not is_prompt
# Decode-phase scenario
assert kv_mmap is not None
assert num_prefill_or_decode_tokens is not None
assert seq_lens is not None
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = num_prefill_or_decode_tokens
(
seq_lens_tensor,
context_lens_tensor,
_,
_,
_,
encoder_seq_lens_tensor,
max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens,
context_lens,
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=max(seq_lens),
context_lens_tensor=context_lens_tensor,
block_tables=kv_mmap.block_tables,
use_cuda_graph=False,
num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping),
cross_block_tables=(None if cross_kv_mmap is None else
cross_kv_mmap.block_tables))
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test: torch.Tensor) -> None:
'''
Assert that observed output matches the ideal output
contained in the test parameters data structure.
Arguments:
* test_params: Test parameters including packed ideal output
* output_under_test: actually observed output value
'''
ideal_output = test_params.packed_qkvo.ideal_output
assert torch.allclose(ideal_output,
output_under_test.view_as(ideal_output))
......@@ -2,6 +2,7 @@ import contextlib
import gc
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
......@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
LONG_LORA_INFOS = [{
class ContextIDInfo(TypedDict):
lora_id: int
context_length: str
class ContextInfo(TypedDict):
lora: str
context_length: str
LONG_LORA_INFOS: List[ContextIDInfo] = [{
"lora_id": 1,
"context_length": "16k",
}, {
......@@ -147,13 +159,21 @@ def dummy_model_gate_up() -> nn.Module:
@pytest.fixture(scope="session")
def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
return "yard1/llama-2-7b-sql-lora-test"
@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")
def mixtral_lora_files():
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
# Note: this module has incorrect adapter_config.json to test
# https://github.com/vllm-project/vllm/pull/5909/files.
return snapshot_download(repo_id="SangBinCho/mixtral-lora")
@pytest.fixture(scope="session")
......@@ -207,7 +227,7 @@ def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup()
infos = {}
infos: Dict[int, ContextInfo] = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1:
......@@ -226,7 +246,7 @@ def long_context_infos(long_context_lora_files_16k_1,
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
def llama_2_7b_engine_extra_embeddings():
cleanup()
get_model_old = get_model
......@@ -244,7 +264,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
@pytest.fixture
def llama_2_7b_model_extra_embeddings(
llama_2_7b_engine_extra_embeddings) -> nn.Module:
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment