Unverified Commit 79028d43 authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Perf] Disable clean_logits in deepgemm fp8_mqa_logits kernel (#33568)

parent 325ab6b0
...@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits( ...@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits(
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only" not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
) )
def test_deepgemm_fp8_mqa_logits(): @pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
torch.manual_seed(0) torch.manual_seed(0)
random.seed(0) random.seed(0)
num_heads, head_dim = 32, 128 num_heads, head_dim = 32, 128
...@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits(): ...@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn) q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) logits = fp8_mqa_logits(
q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits
)
ref_logits = _ref_fp8_mqa_logits( ref_logits = _ref_fp8_mqa_logits(
q=q, q=q,
...@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits(): ...@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits():
cu_seqlen_ks=ks, cu_seqlen_ks=ks,
cu_seqlen_ke=ke, cu_seqlen_ke=ke,
) )
ref_neginf_mask = ref_logits == float("-inf") ref_neginf_mask = ref_logits == float("-inf")
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask) if clean_logits:
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0) logits = logits.masked_fill(ref_neginf_mask, 0)
diff = calc_diff(logits, ref_logits) diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}" assert diff < 1e-3, f"{diff=}"
...@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits( ...@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits(
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only" not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
) )
def test_deepgemm_fp8_paged_mqa_logits(): @pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
torch.manual_seed(0) torch.manual_seed(0)
random.seed(0) random.seed(0)
...@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): ...@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
block_tables, block_tables,
schedule_metadata, schedule_metadata,
max_model_len, max_model_len,
clean_logits=clean_logits,
) )
ref_logits = _ref_fp8_paged_mqa_logits( ref_logits = _ref_fp8_paged_mqa_logits(
......
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
# Test parameters # Test parameters
NUM_ROWS = [1, 32, 2050] NUM_ROWS = [1, 32, 2050]
...@@ -20,6 +21,7 @@ def create_random_logits( ...@@ -20,6 +21,7 @@ def create_random_logits(
row_ends: torch.Tensor, row_ends: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
clean_logits: bool,
data_generation: str, data_generation: str,
) -> torch.Tensor: ) -> torch.Tensor:
"""Create random logits tensor for testing.""" """Create random logits tensor for testing."""
...@@ -48,8 +50,9 @@ def create_random_logits( ...@@ -48,8 +50,9 @@ def create_random_logits(
) )
logits = logits_bits.view(dtype) logits = logits_bits.view(dtype)
for i, end in enumerate(row_ends): if clean_logits:
logits[i, end:] = float("-inf") for i, end in enumerate(row_ends):
logits[i, end:] = float("-inf")
return logits return logits
...@@ -121,21 +124,26 @@ def compare_top_k_results( ...@@ -121,21 +124,26 @@ def compare_top_k_results(
@pytest.mark.parametrize("num_rows", NUM_ROWS) @pytest.mark.parametrize("num_rows", NUM_ROWS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode() @torch.inference_mode()
def test_top_k_per_row( def test_top_k_per_row(
num_rows: int, num_rows: int,
top_k: int, top_k: int,
clean_logits: bool,
) -> None: ) -> None:
""" """
Test top_k_per_row. Test top_k_per_row.
""" """
set_random_seed(0)
torch.set_default_device("cuda:0") torch.set_default_device("cuda:0")
# Create test data # Create test data
vocab_size = 20000 vocab_size = 20000
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size) row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random") logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, clean_logits, "random"
)
# Create output tensors # Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
...@@ -153,11 +161,12 @@ def test_top_k_per_row( ...@@ -153,11 +161,12 @@ def test_top_k_per_row(
) )
# Run reference implementation # Run reference implementation
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
mask_lo = torch_indices >= 0 for i in range(num_rows):
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 row_end = int(row_ends[i])
mask = mask_lo & mask_hi k_i = min(top_k, row_end)
torch_indices = torch_indices.masked_fill(~mask, -1) idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices[i, :k_i] = idx
# Compare results # Compare results
assert compare_top_k_results( assert compare_top_k_results(
...@@ -170,6 +179,7 @@ def _run_top_k_per_row_decode_test( ...@@ -170,6 +179,7 @@ def _run_top_k_per_row_decode_test(
batch_size: int, batch_size: int,
next_n: int, next_n: int,
vocab_size: int, vocab_size: int,
clean_logits: bool,
data_generation: str, data_generation: str,
) -> None: ) -> None:
""" """
...@@ -180,14 +190,18 @@ def _run_top_k_per_row_decode_test( ...@@ -180,14 +190,18 @@ def _run_top_k_per_row_decode_test(
# Create test data # Create test data
num_rows = batch_size * next_n num_rows = batch_size * next_n
seq_lens = torch.randint( seq_lens = torch.randint(
vocab_size, (batch_size,), dtype=torch.int32, device="cuda" low=next_n,
high=vocab_size,
size=(batch_size,),
dtype=torch.int32,
device="cuda",
) )
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda") row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
row_indices = torch.arange(num_rows, device="cuda") // next_n row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits( logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, data_generation row_starts, row_ends, torch.float32, 42, clean_logits, data_generation
) )
# Create output tensors # Create output tensors
...@@ -208,11 +222,12 @@ def _run_top_k_per_row_decode_test( ...@@ -208,11 +222,12 @@ def _run_top_k_per_row_decode_test(
torch.cuda.synchronize() torch.cuda.synchronize()
# Run reference implementation # Run reference implementation
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
mask_lo = torch_indices >= 0 for i in range(num_rows):
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 row_end = int(row_ends[i])
mask = mask_lo & mask_hi k_i = min(top_k, row_end)
torch_indices = torch_indices.masked_fill(~mask, -1) idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
torch_indices[i, :k_i] = idx
# Compare results # Compare results
assert compare_top_k_results( assert compare_top_k_results(
...@@ -223,6 +238,7 @@ def _run_top_k_per_row_decode_test( ...@@ -223,6 +238,7 @@ def _run_top_k_per_row_decode_test(
@pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE) @pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N) @pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.parametrize("data_generation", DATA_GENERATION) @pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode() @torch.inference_mode()
...@@ -230,28 +246,32 @@ def test_top_k_per_row_decode( ...@@ -230,28 +246,32 @@ def test_top_k_per_row_decode(
top_k: int, top_k: int,
batch_size: int, batch_size: int,
next_n: int, next_n: int,
clean_logits: bool,
data_generation: str, data_generation: str,
) -> None: ) -> None:
""" """
Test top_k_per_row with seq_lens tensor. Test top_k_per_row with seq_lens tensor.
""" """
set_random_seed(0)
vocab_size = 20000 vocab_size = 20000
_run_top_k_per_row_decode_test( _run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
) )
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize("clean_logits", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size() -> None: def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
""" """
Test top_k_per_row_decode with large vocabulary size. Test top_k_per_row_decode with large vocabulary size.
""" """
set_random_seed(0)
top_k = 2048 top_k = 2048
batch_size = 2 batch_size = 2
next_n = 2 next_n = 2
vocab_size = 300000 vocab_size = 300000
data_generation = "random" data_generation = "random"
_run_top_k_per_row_decode_test( _run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
) )
...@@ -108,6 +108,7 @@ def sparse_attn_indexer( ...@@ -108,6 +108,7 @@ def sparse_attn_indexer(
weights[chunk.token_start : chunk.token_end], weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
clean_logits=False,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
...@@ -157,6 +158,7 @@ def sparse_attn_indexer( ...@@ -157,6 +158,7 @@ def sparse_attn_indexer(
decode_metadata.block_table, decode_metadata.block_table,
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_model_len=max_model_len, max_model_len=max_model_len,
clean_logits=False,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
......
...@@ -242,6 +242,7 @@ def fp8_mqa_logits( ...@@ -242,6 +242,7 @@ def fp8_mqa_logits(
weights: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor, cu_seqlen_ke: torch.Tensor,
clean_logits: bool,
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging. """Compute FP8 MQA logits for a single sequence without KV paging.
...@@ -256,6 +257,7 @@ def fp8_mqa_logits( ...@@ -256,6 +257,7 @@ def fp8_mqa_logits(
shape [M], dtype int32. shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position, cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32. shape [M], dtype int32.
clean_logits: Whether to clean the unfilled logits into `-inf`.
Returns: Returns:
Logits tensor of shape [M, N], dtype `torch.float32`. Logits tensor of shape [M, N], dtype `torch.float32`.
...@@ -263,7 +265,9 @@ def fp8_mqa_logits( ...@@ -263,7 +265,9 @@ def fp8_mqa_logits(
_lazy_init() _lazy_init()
if _fp8_mqa_logits_impl is None: if _fp8_mqa_logits_impl is None:
return _missing() return _missing()
return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) return _fp8_mqa_logits_impl(
q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=clean_logits
)
def get_paged_mqa_logits_metadata( def get_paged_mqa_logits_metadata(
...@@ -295,6 +299,7 @@ def fp8_paged_mqa_logits( ...@@ -295,6 +299,7 @@ def fp8_paged_mqa_logits(
block_tables: torch.Tensor, block_tables: torch.Tensor,
schedule_metadata: torch.Tensor, schedule_metadata: torch.Tensor,
max_model_len: int, max_model_len: int,
clean_logits: bool,
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache. """Compute FP8 MQA logits using paged KV-cache.
...@@ -312,6 +317,7 @@ def fp8_paged_mqa_logits( ...@@ -312,6 +317,7 @@ def fp8_paged_mqa_logits(
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs. used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output. max_model_len: Maximum sequence length used to size the logits output.
clean_logits: Whether to clean the unfilled logits into `-inf`.
Returns: Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype Logits tensor of shape [B * next_n, max_model_len], dtype
...@@ -328,7 +334,7 @@ def fp8_paged_mqa_logits( ...@@ -328,7 +334,7 @@ def fp8_paged_mqa_logits(
block_tables, block_tables,
schedule_metadata, schedule_metadata,
max_model_len, max_model_len,
clean_logits=True, clean_logits=clean_logits,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment