Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
......@@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
)
from vllm.entrypoints.responses_utils import (
construct_chat_message_with_tool_call,
_construct_single_message_from_response_item,
construct_chat_messages_with_tool_call,
convert_tool_responses_to_completions_format,
)
......@@ -42,7 +44,43 @@ class TestResponsesUtils:
assert result == {"type": "function", "function": input_tool}
def test_construct_chat_message_with_tool_call(self):
def test_construct_chat_messages_with_tool_call(self):
"""Test construction of chat messages with tool calls."""
reasoning_item = ResponseReasoningItem(
id="lol",
summary=[],
type="reasoning",
content=[
Content(
text="Leroy Jenkins",
type="reasoning_text",
)
],
encrypted_content=None,
status=None,
)
mcp_tool_item = ResponseFunctionToolCall(
id="mcp_123",
call_id="call_123",
type="function_call",
status="completed",
name="python",
arguments='{"code": "123+456"}',
)
input_items = [reasoning_item, mcp_tool_item]
messages = construct_chat_messages_with_tool_call(input_items)
assert len(messages) == 1
message = messages[0]
assert message["role"] == "assistant"
assert message["reasoning"] == "Leroy Jenkins"
assert message["tool_calls"][0]["id"] == "call_123"
assert message["tool_calls"][0]["function"]["name"] == "python"
assert (
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
)
def test_construct_single_message_from_response_item(self):
item = ResponseReasoningItem(
id="lol",
summary=[],
......@@ -56,7 +94,7 @@ class TestResponsesUtils:
encrypted_content=None,
status=None,
)
formatted_item = construct_chat_message_with_tool_call(item)
formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant"
assert formatted_item["reasoning"] == "Leroy Jenkins"
......@@ -74,7 +112,7 @@ class TestResponsesUtils:
status=None,
)
formatted_item = construct_chat_message_with_tool_call(item)
formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant"
assert (
formatted_item["reasoning"]
......@@ -88,7 +126,7 @@ class TestResponsesUtils:
output="1234",
status="completed",
)
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
formatted_item = _construct_single_message_from_response_item(tool_call_output)
assert formatted_item["role"] == "tool"
assert formatted_item["content"] == "1234"
assert formatted_item["tool_call_id"] == "temp"
......@@ -102,7 +140,7 @@ class TestResponsesUtils:
status=None,
)
with pytest.raises(ValueError):
construct_chat_message_with_tool_call(item)
_construct_single_message_from_response_item(item)
output_item = ResponseOutputMessage(
id="msg_bf585bbbe3d500e0",
......@@ -119,6 +157,6 @@ class TestResponsesUtils:
type="message",
)
formatted_item = construct_chat_message_with_tool_call(output_item)
formatted_item = _construct_single_message_from_response_item(output_item)
assert formatted_item["role"] == "assistant"
assert formatted_item["content"] == "dongyi"
......@@ -26,7 +26,14 @@ def clear_cache():
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
devices = ["cpu"]
if current_platform.is_cuda():
devices.append("cuda")
if current_platform.is_rocm():
devices.append("hip")
@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(device: str):
"""
Test the attention selector between different platform and device.
......@@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
......
......@@ -8,6 +8,12 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
)
DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
......@@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
GROUP_SIZES = [None, [1, 64], [1, 128]]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
......@@ -45,12 +52,13 @@ def ref_rms_norm(
return out, residual
def ref_dynamic_per_token_quant(
def ref_dynamic_per_token_or_block_quant(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
......@@ -59,6 +67,17 @@ def ref_dynamic_per_token_quant(
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
# Quant
if group_size is not None:
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)
else:
assert quant_dtype == torch.int8
torch_out, scales = per_token_group_quant_int8(
torch_out, group_size=group_size[1]
)
else:
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
......@@ -76,21 +95,29 @@ def ref_impl(
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ref_dynamic_per_token_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub
return ref_dynamic_per_token_or_block_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
)
def ops_dynamic_per_token_quant(
def ops_dynamic_per_token_or_block_quant(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if residual is not None:
residual = residual.clone()
if group_size is not None:
out, scales = ops.rms_norm_per_block_quant(
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True
)
scales = scales.contiguous()
else:
out, scales = ops.rms_norm_dynamic_per_token_quant(
x, weight, EPS, quant_dtype, scale_ub, residual
)
......@@ -103,8 +130,11 @@ def ops_impl(
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
return ops_dynamic_per_token_or_block_quant(
weight, x, quant_dtype, residual, scale_ub, group_size
)
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
......@@ -112,6 +142,7 @@ def ops_impl(
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
......@@ -122,6 +153,7 @@ def test_rms_norm(
has_scale_ub: bool,
dtype: torch.dtype,
quant_dtype: torch.dtype,
group_size: list[int] | None,
seed: int,
device: str,
) -> None:
......@@ -130,6 +162,14 @@ def test_rms_norm(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
if group_size is not None and hidden_size % group_size[1] != 0:
# skip
return
if group_size is not None and has_scale_ub:
# blockwise baseline doesn't support scale_ub
return
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
# skip
return
......@@ -150,10 +190,10 @@ def test_rms_norm(
scale_ub = None
ref_out, ref_scales, ref_residual = ref_impl(
layer, x, quant_dtype, residual, scale_ub
layer, x, quant_dtype, residual, scale_ub, group_size
)
ops_out, ops_scales, ops_residual = ops_impl(
layer.weight, x, quant_dtype, residual, scale_ub
layer.weight, x, quant_dtype, residual, scale_ub, group_size
)
assert ref_out.dtype == quant_dtype
......@@ -166,11 +206,15 @@ def test_rms_norm(
assert torch.allclose(ref_scales, ops_scales)
a = ref_out.to(dtype=torch.float32)
b = ops_out.to(dtype=torch.float32)
ok = torch.allclose(a, b)
ok = torch.allclose(a, b, atol=1e-6)
if not ok:
# fallback: compare dequantized values with relaxed tolerance
if group_size is None:
a_deq = a * ref_scales.view(-1, 1)
b_deq = b * ops_scales.view(-1, 1)
else:
a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
# NOTE: It is possible that some future test cases trigger this
# max diff due to precision issues. If such an error is
# encountered, it's recommended to inspect the differences between
......
......@@ -113,12 +113,10 @@ def test_mrope(
is_neox_style = True
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=rotary_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters,
......@@ -184,12 +182,10 @@ def test_mrope_torch_compile_tracing(
)
is_neox_style = True
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=rotary_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters,
......
......@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 5e-2, 1.5e-1
if torch.version.hip:
atol *= 2
# set seed
current_platform.seed_everything(0)
batch_size = 4
token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(token_counts.sum().item())
cu_seqlens = torch.tensor(
[0] + torch.cumsum(token_counts, dim=0).tolist(),
dtype=torch.int32,
device=device,
)
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(total_tokens, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(total_tokens, dstate, device=device)
C = torch.randn(total_tokens, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
cu_seqlens=cu_seqlens,
)
out_ref_list = []
for seq_idx in range(batch_size):
start_idx = cu_seqlens[seq_idx].item()
end_idx = cu_seqlens[seq_idx + 1].item()
num_tokens = end_idx - start_idx
for token_idx in range(num_tokens):
idx = start_idx + token_idx
out_ref_list.append(
selective_state_update_ref(
state_ref[seq_idx : seq_idx + 1],
x[idx : idx + 1],
dt[idx : idx + 1],
A,
B[idx : idx + 1],
C[idx : idx + 1],
D=D,
z=z[idx : idx + 1] if has_z else None,
dt_bias=dt_bias,
dt_softplus=True,
)
)
out_ref = torch.cat(out_ref_list, dim=0)
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
......@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_with_num_accepted_tokens(
dim, dstate, has_z, itype, max_seq_len
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 5e-2, 1.5e-1
if torch.version.hip:
atol *= 2
current_platform.seed_everything(0)
batch_size = 4
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(tokens_per_seq.sum().item())
num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
num_accepted_tokens[0] = 0 # Add edge-case of no accepted tokens
num_accepted_tokens[1] = max_seq_len # Add edge-case of all tokens accepted
cu_seqlens = torch.tensor(
[0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
dtype=torch.int32,
device=device,
)
total_state_slots = 50
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
slot_offset = 15
dst_slots_map = {}
for seq_idx in range(batch_size):
for token_idx in range(tokens_per_seq[seq_idx].item()):
dst_state_batch_indices[seq_idx, token_idx] = slot_offset
dst_slots_map[(seq_idx, token_idx)] = slot_offset
slot_offset += 1
x = torch.randn(total_tokens, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(total_tokens, dstate, device=device)
C = torch.randn(total_tokens, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref_intermediate = {}
out_ref_list = []
for seq_idx in range(batch_size):
seq_start = cu_seqlens[seq_idx].item()
seq_end = cu_seqlens[seq_idx + 1].item()
num_tokens = seq_end - seq_start
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
initial_slot = state_batch_indices[seq_idx, token_pos].item()
state_seq = state[initial_slot : initial_slot + 1].clone()
for token_idx in range(num_tokens):
global_idx = seq_start + token_idx
out_token = selective_state_update_ref(
state_seq,
x[global_idx : global_idx + 1],
dt[global_idx : global_idx + 1],
A,
B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1],
D=D,
z=z[global_idx : global_idx + 1] if has_z else None,
dt_bias=dt_bias,
dt_softplus=True,
)
out_ref_list.append(out_token)
state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
out_ref = torch.cat(out_ref_list, dim=0)
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
cu_seqlens=cu_seqlens,
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
for seq_idx in range(batch_size):
num_tokens = tokens_per_seq[seq_idx].item()
for token_idx in range(num_tokens):
dst_slot = dst_slots_map[(seq_idx, token_idx)]
state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_varlen_with_num_accepted(
dim, dstate, has_z, itype, max_seq_len
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 5e-2, 1.5e-1
if torch.version.hip:
atol *= 2
current_platform.seed_everything(0)
batch_size = 4
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(tokens_per_seq.sum().item())
num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
num_accepted_tokens[0] = 0 # Add edge-case of no accepted tokens
num_accepted_tokens[1] = max_seq_len # Add edge-case of all tokens accepted
cu_seqlens = torch.tensor(
[0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
dtype=torch.int32,
device=device,
)
total_state_slots = 50
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
slot_offset = 15
dst_slots_map = {}
for seq_idx in range(batch_size):
for token_idx in range(tokens_per_seq[seq_idx].item()):
dst_state_batch_indices[seq_idx, token_idx] = slot_offset
dst_slots_map[(seq_idx, token_idx)] = slot_offset
slot_offset += 1
x = torch.randn(total_tokens, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(total_tokens, dstate, device=device)
C = torch.randn(total_tokens, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref_intermediate = {}
for seq_idx in range(batch_size):
seq_start = cu_seqlens[seq_idx].item()
seq_end = cu_seqlens[seq_idx + 1].item()
num_tokens = seq_end - seq_start
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
initial_slot = state_batch_indices[seq_idx, token_pos].item()
state_seq = state[initial_slot : initial_slot + 1].clone()
for token_idx in range(num_tokens):
global_idx = seq_start + token_idx
selective_state_update_ref(
state_seq,
x[global_idx : global_idx + 1],
dt[global_idx : global_idx + 1],
A,
B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1],
D=D,
z=z[global_idx : global_idx + 1] if has_z else None,
dt_bias=dt_bias,
dt_softplus=True,
)
state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
cu_seqlens=cu_seqlens,
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
for seq_idx in range(batch_size):
num_tokens = tokens_per_seq[seq_idx].item()
for token_idx in range(num_tokens):
dst_slot = dst_slots_map[(seq_idx, token_idx)]
state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
......@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
......@@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts,
standard_format,
......@@ -457,10 +444,6 @@ def make_fused_experts(
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
......
......@@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck():
@pytest.mark.parametrize("ep_size", [1, 2])
def test_moe_align_block_size_opcheck(ep_size):
num_experts = 4
block_size = 4
expert_map = None
if ep_size != 1:
local_num_experts = num_experts // ep_size
expert_ids = torch.randint(
0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32
)
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
expert_map[expert_ids] = torch.arange(
local_num_experts, device="cuda", dtype=torch.int32
)
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
......@@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
sorted_ids,
expert_ids,
num_tokens_post_pad,
expert_map,
),
)
......
......@@ -106,6 +106,8 @@ def torch_moe_align_block_size(
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
if topk_ids.numel() < num_experts:
max_num_tokens_padded = topk_ids.numel() * block_size
flattened_token_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
......@@ -126,6 +128,8 @@ def torch_moe_align_block_size(
)
for expert_id in range(num_experts):
original_count = expert_token_counts[expert_id]
if expert_map is not None and expert_map[expert_id] == -1:
continue
if original_count > 0:
expert_padded_counts[expert_id] = (
(original_count + block_size - 1) // block_size
......@@ -143,6 +147,9 @@ def torch_moe_align_block_size(
current_pos = 0
current_block = 0
for expert_id in range(num_experts):
if expert_map is not None and expert_map[expert_id] == -1:
continue
expert_mask = sorted_expert_ids == expert_id
expert_tokens = sorted_token_indices[expert_mask]
num_expert_tokens = expert_tokens.shape[0]
......@@ -153,7 +160,13 @@ def torch_moe_align_block_size(
)
expert_blocks_needed = expert_padded_counts[expert_id] // block_size
expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
expert_id_new = expert_id
if expert_map is not None:
expert_id_new = expert_map[expert_id]
expert_ids[current_block : current_block + expert_blocks_needed] = (
expert_id_new
)
current_pos += expert_padded_counts[expert_id]
current_block += expert_blocks_needed
......@@ -163,8 +176,6 @@ def torch_moe_align_block_size(
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_token_ids, expert_ids, num_tokens_post_pad
......@@ -229,9 +240,9 @@ def test_moe_align_block_size(
)
@pytest.mark.parametrize("m", [16, 32])
@pytest.mark.parametrize("m", [16, 32, 2048])
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("num_experts", [8, 64])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(
......@@ -253,6 +264,7 @@ def test_moe_align_block_size_with_expert_map(
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
ignore_invalid_experts=True,
)
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
torch_moe_align_block_size(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
FLOAT8_DTYPE = torch.float8_e4m3fn
GROUP_SIZE = 128
def reference_quant(x: torch.Tensor, use_ue8m0: bool):
"""
Reference triton quant kernel from,
vllm.model_executor.layers.quantization.utils.fp8_utils
"""
x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)
# Allocate the scale tensor in column-major format.
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
M = x.numel() // GROUP_SIZE
N = GROUP_SIZE
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
finfo = torch.finfo(FLOAT8_DTYPE)
fp8_min = finfo.min
fp8_max = finfo.max
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
GROUP_SIZE,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
T, N = x.size()
ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
torch.ops._C.silu_and_mul(ref_act_out, x)
return reference_quant(ref_act_out, use_ue8m0)
@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
current_platform.seed_everything(42)
input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
use_ue8m0 = is_deep_gemm_e8m0_used()
# Test
output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
input, use_ue8m0=use_ue8m0
)
# Reference
ref_output, ref_output_scales = reference(input, use_ue8m0)
torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
torch.testing.assert_close(output_scales, ref_output_scales)
......@@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE)
)
return ref_out, ref_scale.view((1, 1))
return ref_out, ref_scale.view(1)
def native_w8a8_block_matmul(
......
......@@ -54,6 +54,10 @@ def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
......@@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_info = torch.finfo(current_platform.fp8_dtype())
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
......@@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
)
@torch.inference_mode()
def test_w8a8_block_fp8_cutlass_matmul():
# Test simple case where weight.shape % 128 != 0,
......@@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
assert rel_diff < 0.001
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
......
......@@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),
......
......@@ -12,12 +12,18 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_packed_uint4b8_to_signed_int4_inplace,
pack_cols,
pack_rows,
quantize_weights,
unpack_quantized_values_into_int32,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
......@@ -167,8 +173,7 @@ def create_test_tensors(
# for the practical use case we need per-tok scales for fp8 activations
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
# weights are already per-group quantized, use placeholder here
w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)
w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type)
return Tensors(
w_ref=w_ref,
......@@ -211,7 +216,7 @@ def mm_test_helper(
print(output_ref)
torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2
)
......@@ -257,7 +262,7 @@ def test_w4a8_cuda_graph():
)
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)
w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32)
# Construct a trivial model with a single layer that calls the kernel
model = W4A8Layer(
......@@ -287,4 +292,38 @@ def test_w4a8_cuda_graph():
output.zero_()
g.replay()
torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES)
def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
"""
The W4A16 checkpoints encode the weights as int4b8 packed to int32.
The CUTLASS kernels expect signed int4 packed to int32.
This tests checks that the runtime int4b8 -> signed int4 conversion
matches the offline conversion step exactly.
"""
_, N, K = shape
# random weights packed to int32
t = torch.randint(
low=torch.iinfo(torch.int32).min,
high=torch.iinfo(torch.int32).max + 1,
size=(N, K // 8),
dtype=torch.int32,
device="cuda",
)
# compute reference
unpacked = unpack_quantized_values_into_int32(
t.clone(), scalar_types.uint4b8, packed_dim=1
)
unpacked = unpacked - 8 # int4b8 -> signed int4
ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape)
out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone())
assert torch.equal(ref, out)
assert not torch.equal(ref, t)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
"""
import random
from dataclasses import dataclass
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
def cutlass_quantize(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
"""
Quantize weights into W4 and compute reference dequantized weights.
Encoding/reordering of weights and packing of scales is deferred
until after all experts are combined.
"""
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
w, wtype, group_size=group_size, zero_points=zero_points
)
# Since scales are later cast to fp8, recompute w_ref in atype here.
w_ref = (
w_q.to(torch.float32)
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
).to(atype)
# Bit mask prevents sign extension of int4 when packing.
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
# Make weights row-major (N, K).
w_q = w_q.t().contiguous()
return w_ref, w_q, w_s.to(atype), w_zp
def cutlass_preprocess(
w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
):
"""
Reorder/encode expert weights and pack scales.
Returns:
w_q_packed: Packed/encoded int4 weights for all experts.
w_s_packed: Packed fp8 scales for all experts.
packed_layout: Layout/stride metadata for grouped GEMM.
"""
w_s_packed = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts))
w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
torch.stack(w_q_experts)
) # expects dim 3
return w_q_packed, w_s_packed, packed_layout
GROUP_SIZE = 128
# (num_experts, N, K)
TEST_SHAPES = [
(8, 512, 2048),
(8, 2048, 2048),
(64, 512, 1024),
(64, 2048, 2048),
(4, 2048, 768),
(8, 768, 2048),
(64, 1536, 2048),
(128, 8192, 4096), # test overflow int32
]
ALIGNMENT = 16 # torch._scaled_mm alignment for M, needed for reference check
@dataclass
class MoETestSetup:
num_experts: int
K: int
N: int
Ms: list[int]
M_full: int
a: torch.Tensor
a_ref: torch.Tensor
a_strides: torch.Tensor
out: torch.Tensor
c_strides: torch.Tensor
per_tok_scales: torch.Tensor
per_chan_scales: torch.Tensor
w_refs: list[torch.Tensor]
w_q_packed: torch.Tensor
w_s_packed: torch.Tensor
problem_sizes: torch.Tensor
expert_offsets: torch.Tensor
b_strides: torch.Tensor
group_scale_strides: torch.Tensor
def make_moe_test_setup(
num_experts: int,
K: int,
N: int,
*,
alignment: int = ALIGNMENT,
max_blocks: int = 64,
device: str = "cuda",
random_zero: bool = False,
) -> MoETestSetup:
"""Create a full set of tensors for testing cutlass_w4a8_moe_mm."""
assert K % GROUP_SIZE == 0
# Token counts per expert (multiples of `alignment`).
Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)]
# set random experts to 0 tokens
if random_zero and num_experts > 1:
num_zero = max(1, num_experts // 8)
zero_indices = random.sample(range(num_experts), k=num_zero)
for idx in zero_indices:
Ms[idx] = 0
M_full = sum(Ms)
assert M_full > 0
# Activations.
a = to_fp8(torch.randn((M_full, K), device=device))
a_ref = a.to(torch.float32)
a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device)
# Output buffer.
out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)
# Channel/token scales.
per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
per_chan_scales = torch.randn(
(num_experts, N, 1), dtype=torch.float32, device=device
)
# Expert weights and scales.
wtype = scalar_types.int4
atype = stype = torch.float8_e4m3fn
w_refs, w_qs, w_ss = [], [], []
for _ in range(num_experts):
b = to_fp8(torch.randn((K, N), device=device))
w_ref, w_q, w_s, _ = cutlass_quantize(
atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False
)
w_refs.append(w_ref)
w_qs.append(w_q)
w_ss.append(w_s)
w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)
problem_sizes = torch.tensor(
[[N, M, K] for M in Ms], dtype=torch.int32, device=device
)
expert_offsets = torch.cat(
[
torch.tensor([0], dtype=torch.int64),
torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1],
]
).to(device=device)
# B strides and group scale strides.
b_strides = packed_layout
group_scale_strides = torch.zeros(
(num_experts, 2), dtype=torch.int64, device=device
)
group_scale_strides[:, 0] = N
return MoETestSetup(
num_experts=num_experts,
K=K,
N=N,
Ms=Ms,
M_full=M_full,
a=a,
a_ref=a_ref,
a_strides=a_strides,
out=out,
c_strides=c_strides,
per_tok_scales=per_tok_scales,
per_chan_scales=per_chan_scales,
w_refs=w_refs,
w_q_packed=w_q_packed,
w_s_packed=w_s_packed,
problem_sizes=problem_sizes,
expert_offsets=expert_offsets,
b_strides=b_strides,
group_scale_strides=group_scale_strides,
)
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
"""Compute reference output using torch._scaled_mm per expert."""
out_ref = torch.empty_like(setup.out)
ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()
starts = setup.expert_offsets.cpu().tolist()
for i in range(setup.num_experts):
start, end = starts[i], ends[i]
if start == end:
continue
out_ref_i = torch._scaled_mm(
setup.a_ref[start:end].to(torch.float8_e4m3fn),
setup.w_refs[i].to(torch.float8_e4m3fn).t().contiguous().t(),
setup.per_tok_scales[start:end], # (M, 1)
setup.per_chan_scales[i].reshape(1, -1), # (1, N)
out_dtype=torch.bfloat16,
use_fast_accum=True,
)
out_ref[start:end] = out_ref_i
return out_ref
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU,
reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("random_zero", [True, False])
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
num_experts, N, K = shape
current_platform.seed_everything(42)
setup = make_moe_test_setup(
num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero
)
ops.cutlass_w4a8_moe_mm(
setup.out,
setup.a,
setup.w_q_packed,
setup.per_tok_scales,
setup.per_chan_scales,
setup.w_s_packed,
GROUP_SIZE,
setup.expert_offsets,
setup.problem_sizes,
setup.a_strides,
setup.b_strides,
setup.c_strides,
setup.group_scale_strides,
)
torch.cuda.synchronize()
out_ref = compute_moe_reference_output(setup)
torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2)
class W4A8MoELayer(torch.nn.Module):
"""
Minimal wrapper module to test cuda graphs
"""
def __init__(self, setup: MoETestSetup):
super().__init__()
self.setup = setup
def forward(self, a: torch.Tensor) -> torch.Tensor:
s = self.setup
ops.cutlass_w4a8_moe_mm(
s.out,
a,
s.w_q_packed,
s.per_tok_scales,
s.per_chan_scales,
s.w_s_packed,
GROUP_SIZE,
s.expert_offsets,
s.problem_sizes,
s.a_strides,
s.b_strides,
s.c_strides,
s.group_scale_strides,
)
return s.out
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU,
reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
def test_cutlass_w4a8_moe_mm_cuda_graph():
current_platform.seed_everything(42)
# Fixed config for CUDA graph test (single parameter point).
num_experts = 8
K = 512
N = 2048
setup = make_moe_test_setup(
num_experts=num_experts,
K=K,
N=N,
max_blocks=32,
)
# Construct model that calls the grouped GEMM kernel.
model = W4A8MoELayer(setup)
# Build reference output once.
out_ref = compute_moe_reference_output(setup)
# Capture and run the model in a CUDA graph.
a_static = setup.a.clone() # static input tensor for graph replay
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
out_static = model(a_static)
out_static.zero_()
g.replay()
torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2)
......@@ -8,6 +8,13 @@ import torch
from compressed_tensors.transform import deterministic_hadamard_matrix
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
if current_platform.is_rocm():
pytest.skip(
"These tests require hadacore_transform, not supported on ROCm.",
allow_module_level=True,
)
@pytest.mark.parametrize("batch_size", [1, 32])
......
......@@ -23,6 +23,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
if current_platform.is_rocm():
pytest.skip(
"These tests require machete_prepack_B, not supported on ROCm.",
allow_module_level=True,
)
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
......
......@@ -56,6 +56,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if current_platform.is_rocm():
pytest.skip(
"These tests require gptq_marlin_repack,"
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
"or gptq_marlin_gemm which are not supported on ROCm.",
allow_module_level=True,
)
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
......
......@@ -9,23 +9,45 @@ from vllm.platforms import current_platform
# Test parameters
NUM_ROWS = [1, 32, 2050]
TOP_K_VALUES = [2048]
BATCH_SIZE = [1, 2, 4, 2048, 4096]
NEXT_N = [1, 2, 4, 8]
TOP_K_VALUES = [2048, 3000]
BATCH_SIZE = [1, 2, 2048]
NEXT_N = [1, 8]
DATA_GENERATION = ["random", "10LSBits"]
def create_random_logits(
row_starts: torch.Tensor,
row_ends: torch.Tensor,
vocab_size: int,
dtype: torch.dtype,
seed: int,
data_generation: str,
) -> torch.Tensor:
"""Create random logits tensor for testing."""
torch.manual_seed(seed)
np.random.seed(seed)
# Generate logits with some structure to make testing more meaningful
logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda")
if data_generation == "random":
logits = torch.randn(
row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda"
)
elif data_generation == "10LSBits":
top_22_bits_mask = 0xFFFFFC00
last_10_bits_mask = 0x000003FF
fixed_top_22_bits = 0x3F900000
# Generate random bits for the last 10 bits
random_bottom_bits = torch.randint(
0,
2**10,
(row_starts.shape[0], max(row_ends)),
dtype=torch.int32,
device="cuda",
)
# Combine: fixed top 22 bits with random last 10 bits
logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
random_bottom_bits & last_10_bits_mask
)
logits = logits_bits.view(dtype)
for i, end in enumerate(row_ends):
logits[i, end:] = float("-inf")
return logits
......@@ -113,13 +135,13 @@ def test_top_k_per_row(
# Create test data
vocab_size = 20000
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
# Run CUDA implementation
torch.ops._C.top_k_per_row(
torch.ops._C.top_k_per_row_prefill(
logits,
row_starts,
row_ends,
......@@ -127,6 +149,7 @@ def test_top_k_per_row(
num_rows,
logits.stride(0),
logits.stride(1),
top_k,
)
# Run reference implementation
......@@ -139,27 +162,23 @@ def test_top_k_per_row(
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
), "CUDA top_k_per_row_prefill results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
def _run_top_k_per_row_decode_test(
top_k: int,
batch_size: int,
next_n: int,
vocab_size: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
Helper function to run top_k_per_row_decode test with given parameters.
"""
torch.set_default_device("cuda:0")
# Create test data
num_rows = batch_size * next_n
vocab_size = 20000
seq_lens = torch.randint(
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
)
......@@ -167,7 +186,9 @@ def test_top_k_per_row_decode(
row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, data_generation
)
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
......@@ -181,6 +202,7 @@ def test_top_k_per_row_decode(
num_rows,
logits.stride(0),
logits.stride(1),
top_k,
)
torch.cuda.synchronize()
......@@ -195,4 +217,41 @@ def test_top_k_per_row_decode(
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
), "CUDA top_k_per_row_decode results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
top_k: int,
batch_size: int,
next_n: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
"""
vocab_size = 20000
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size() -> None:
"""
Test top_k_per_row_decode with large vocabulary size.
"""
top_k = 2048
batch_size = 2
next_n = 2
vocab_size = 300000
data_generation = "random"
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
# TODO: the test depends on a lot of fields in the current implementation.
# We should have standard interface instead direct field access
def test_run(my_rank, buffer, device):
# buffer should be empty in the beginning
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print(f"My rank: {my_rank}, device: {device}")
# insert
tokens = torch.tensor([1, 2, 3]).to(device)
roi = tokens > 0
if my_rank == 0:
key = 2.0 * torch.ones([5, 6]).to(device)
value = 3.0 * torch.ones([5, 6]).to(device)
placeholder = torch.tensor([1]).to(device)
buffer.insert(tokens, roi, key, value, placeholder)
torch.distributed.barrier()
# drop_select
if my_rank == 1:
tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
assert torch.allclose(tokens, tok)
assert torch.allclose(roi, roi_)
assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
torch.distributed.barrier()
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print(f"My rank: {my_rank}, Test run passed!")
def stress_test(my_rank, buf, device):
torch.distributed.barrier()
torch.manual_seed(100)
reqs = [
(
torch.rand(100).to(device), # tokens
torch.ones(100).bool().to(device), # roi
torch.rand(100).to(device), # key
torch.rand(100).to(device), # value
torch.rand(100).to(device), # hidden
)
for i in tqdm(range(200))
]
random.seed(my_rank)
random.shuffle(reqs)
torch.distributed.barrier()
n = 0
# the buffer size can only store 100 reqs
# so the sender will occasionally block to wait for the receiver.
for req in tqdm(reqs):
if my_rank == 0:
buf.insert(*req)
else:
tok, roi, k, v, h = req
tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)
if tok_ is None:
assert roi_ is None
assert k_ is None
assert v_ is None
assert h_ is None
n += 1
else:
assert torch.allclose(tok, tok_)
assert torch.allclose(roi, roi_)
assert torch.allclose(k, k_)
assert torch.allclose(v, v_)
assert torch.allclose(h, h_)
print(f"Rank {my_rank} done")
torch.distributed.barrier()
if my_rank == 0:
x = torch.tensor([0])
torch.distributed.recv(x, 1)
# the # of None received is the kv that are not selected
assert x.item() == len(buf.buffer)
# and the size of the buffer should be 2000 * buffer len
print(buf.buffer_size)
assert buf.buffer_size == 1700 * len(buf.buffer)
else:
torch.distributed.send(torch.tensor([n]), 0)
print(f"My rank: {my_rank}, Passed stress test!")
if __name__ == "__main__":
my_rank = int(os.environ["RANK"])
torch.distributed.init_process_group(
backend="gloo",
init_method="tcp://localhost:12398",
world_size=2,
rank=my_rank,
)
print(f"initialized! My rank is {my_rank}")
config = KVTransferConfig(
kv_connector="P2pNcclConnector",
kv_buffer_device="cuda",
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
data_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cuda",
port_offset=0,
)
cpu_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cpu",
port_offset=1,
)
buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)
test_run(my_rank, buffer, data_pipe.device)
stress_test(my_rank, buffer, data_pipe.device)
buffer.close()
data_pipe.close()
cpu_pipe.close()
print("Done")
#!/bin/bash
RANK=0 python3 test_lookup_buffer.py &
PID0=$!
RANK=1 python3 test_lookup_buffer.py &
PID1=$!
wait $PID0
wait $PID1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment