Unverified Commit 184076c3 authored by Daniel Cámpora's avatar Daniel Cámpora Committed by GitHub
Browse files

[DeepSeek v3.2] Make top-k work for any logit values. (#27568)


Signed-off-by: default avatarDaniel Campora <961215+dcampora@users.noreply.github.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent eb1051fb
...@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits, ...@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask, const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties); const torch::Tensor& repetition_penalties);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices, const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1); int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seq_lens, torch::Tensor& indices, const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1); int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, torch::Tensor& weight, torch::Tensor& scale,
......
This diff is collapsed.
...@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Optimized top-k per row operation // Optimized top-k per row operation
ops.def( ops.def(
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, int numRows, int stride0, " "Tensor! indices, int numRows, int stride0, "
"int stride1) -> ()"); "int stride1, int topK) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
ops.def( ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, " "top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, int numRows, " "Tensor seq_lens, Tensor! indices, "
"int stride0, int stride1) -> ()"); "int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
// Layernorm-quant // Layernorm-quant
......
...@@ -9,23 +9,45 @@ from vllm.platforms import current_platform ...@@ -9,23 +9,45 @@ from vllm.platforms import current_platform
# Test parameters # Test parameters
NUM_ROWS = [1, 32, 2050] NUM_ROWS = [1, 32, 2050]
TOP_K_VALUES = [2048] TOP_K_VALUES = [2048, 3000]
BATCH_SIZE = [1, 2, 4, 2048, 4096] BATCH_SIZE = [1, 2, 2048]
NEXT_N = [1, 2, 4, 8] NEXT_N = [1, 8]
DATA_GENERATION = ["random", "10LSBits"]
def create_random_logits( def create_random_logits(
row_starts: torch.Tensor, row_starts: torch.Tensor,
row_ends: torch.Tensor, row_ends: torch.Tensor,
vocab_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
data_generation: str,
) -> torch.Tensor: ) -> torch.Tensor:
"""Create random logits tensor for testing.""" """Create random logits tensor for testing."""
torch.manual_seed(seed) torch.manual_seed(seed)
np.random.seed(seed) np.random.seed(seed)
# Generate logits with some structure to make testing more meaningful # Generate logits with some structure to make testing more meaningful
logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda") if data_generation == "random":
logits = torch.randn(
row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda"
)
elif data_generation == "10LSBits":
top_22_bits_mask = 0xFFFFFC00
last_10_bits_mask = 0x000003FF
fixed_top_22_bits = 0x3F900000
# Generate random bits for the last 10 bits
random_bottom_bits = torch.randint(
0,
2**10,
(row_starts.shape[0], max(row_ends)),
dtype=torch.int32,
device="cuda",
)
# Combine: fixed top 22 bits with random last 10 bits
logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
random_bottom_bits & last_10_bits_mask
)
logits = logits_bits.view(dtype)
for i, end in enumerate(row_ends): for i, end in enumerate(row_ends):
logits[i, end:] = float("-inf") logits[i, end:] = float("-inf")
return logits return logits
...@@ -113,13 +135,13 @@ def test_top_k_per_row( ...@@ -113,13 +135,13 @@ def test_top_k_per_row(
# Create test data # Create test data
vocab_size = 20000 vocab_size = 20000
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size) row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
# Create output tensors # Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
# Run CUDA implementation # Run CUDA implementation
torch.ops._C.top_k_per_row( torch.ops._C.top_k_per_row_prefill(
logits, logits,
row_starts, row_starts,
row_ends, row_ends,
...@@ -127,6 +149,7 @@ def test_top_k_per_row( ...@@ -127,6 +149,7 @@ def test_top_k_per_row(
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
top_k,
) )
# Run reference implementation # Run reference implementation
...@@ -139,27 +162,23 @@ def test_top_k_per_row( ...@@ -139,27 +162,23 @@ def test_top_k_per_row(
# Compare results # Compare results
assert compare_top_k_results( assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk" ), "CUDA top_k_per_row_prefill results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES) def _run_top_k_per_row_decode_test(
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
top_k: int, top_k: int,
batch_size: int, batch_size: int,
next_n: int, next_n: int,
vocab_size: int,
data_generation: str,
) -> None: ) -> None:
""" """
Test top_k_per_row with seq_lens tensor. Helper function to run top_k_per_row_decode test with given parameters.
""" """
torch.set_default_device("cuda:0") torch.set_default_device("cuda:0")
# Create test data # Create test data
num_rows = batch_size * next_n num_rows = batch_size * next_n
vocab_size = 20000
seq_lens = torch.randint( seq_lens = torch.randint(
vocab_size, (batch_size,), dtype=torch.int32, device="cuda" vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
) )
...@@ -167,7 +186,9 @@ def test_top_k_per_row_decode( ...@@ -167,7 +186,9 @@ def test_top_k_per_row_decode(
row_indices = torch.arange(num_rows, device="cuda") // next_n row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, data_generation
)
# Create output tensors # Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
...@@ -181,6 +202,7 @@ def test_top_k_per_row_decode( ...@@ -181,6 +202,7 @@ def test_top_k_per_row_decode(
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
top_k,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -195,4 +217,41 @@ def test_top_k_per_row_decode( ...@@ -195,4 +217,41 @@ def test_top_k_per_row_decode(
# Compare results # Compare results
assert compare_top_k_results( assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk" ), "CUDA top_k_per_row_decode results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
top_k: int,
batch_size: int,
next_n: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
"""
vocab_size = 20000
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size() -> None:
"""
Test top_k_per_row_decode with large vocabulary size.
"""
top_k = 2048
batch_size = 2
next_n = 2
vocab_size = 300000
data_generation = "random"
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
...@@ -684,11 +684,10 @@ def sparse_attn_indexer( ...@@ -684,11 +684,10 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[ topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens chunk.token_start : chunk.token_end, :topk_tokens
] ]
torch.ops._C.top_k_per_row( torch.ops._C.top_k_per_row_prefill(
logits, logits,
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
...@@ -696,6 +695,7 @@ def sparse_attn_indexer( ...@@ -696,6 +695,7 @@ def sparse_attn_indexer(
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
topk_tokens,
) )
if has_decode: if has_decode:
...@@ -738,7 +738,6 @@ def sparse_attn_indexer( ...@@ -738,7 +738,6 @@ def sparse_attn_indexer(
max_model_len=max_model_len, max_model_len=max_model_len,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode( torch.ops._C.top_k_per_row_decode(
...@@ -749,6 +748,7 @@ def sparse_attn_indexer( ...@@ -749,6 +748,7 @@ def sparse_attn_indexer(
num_rows, num_rows,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
topk_tokens,
) )
if decode_metadata.requires_padding: if decode_metadata.requires_padding:
# if padded, we need to unpack # if padded, we need to unpack
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment