Unverified Commit 184076c3 authored by Daniel Cámpora's avatar Daniel Cámpora Committed by GitHub
Browse files

[DeepSeek v3.2] Make top-k work for any logit values. (#27568)


Signed-off-by: default avatarDaniel Campora <961215+dcampora@users.noreply.github.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent eb1051fb
......@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seq_lens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
......
This diff is collapsed.
......@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Optimized top-k per row operation
ops.def(
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, int numRows, int stride0, "
"int stride1) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
"int stride1, int topK) -> ()");
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, int numRows, "
"int stride0, int stride1) -> ()");
"Tensor seq_lens, Tensor! indices, "
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
// Layernorm-quant
......
......@@ -9,23 +9,45 @@ from vllm.platforms import current_platform
# Test parameters
NUM_ROWS = [1, 32, 2050]
TOP_K_VALUES = [2048]
BATCH_SIZE = [1, 2, 4, 2048, 4096]
NEXT_N = [1, 2, 4, 8]
TOP_K_VALUES = [2048, 3000]
BATCH_SIZE = [1, 2, 2048]
NEXT_N = [1, 8]
DATA_GENERATION = ["random", "10LSBits"]
def create_random_logits(
row_starts: torch.Tensor,
row_ends: torch.Tensor,
vocab_size: int,
dtype: torch.dtype,
seed: int,
data_generation: str,
) -> torch.Tensor:
"""Create random logits tensor for testing."""
torch.manual_seed(seed)
np.random.seed(seed)
# Generate logits with some structure to make testing more meaningful
logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda")
if data_generation == "random":
logits = torch.randn(
row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda"
)
elif data_generation == "10LSBits":
top_22_bits_mask = 0xFFFFFC00
last_10_bits_mask = 0x000003FF
fixed_top_22_bits = 0x3F900000
# Generate random bits for the last 10 bits
random_bottom_bits = torch.randint(
0,
2**10,
(row_starts.shape[0], max(row_ends)),
dtype=torch.int32,
device="cuda",
)
# Combine: fixed top 22 bits with random last 10 bits
logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
random_bottom_bits & last_10_bits_mask
)
logits = logits_bits.view(dtype)
for i, end in enumerate(row_ends):
logits[i, end:] = float("-inf")
return logits
......@@ -113,13 +135,13 @@ def test_top_k_per_row(
# Create test data
vocab_size = 20000
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
# Run CUDA implementation
torch.ops._C.top_k_per_row(
torch.ops._C.top_k_per_row_prefill(
logits,
row_starts,
row_ends,
......@@ -127,6 +149,7 @@ def test_top_k_per_row(
num_rows,
logits.stride(0),
logits.stride(1),
top_k,
)
# Run reference implementation
......@@ -139,27 +162,23 @@ def test_top_k_per_row(
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
), "CUDA top_k_per_row_prefill results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
def _run_top_k_per_row_decode_test(
top_k: int,
batch_size: int,
next_n: int,
vocab_size: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
Helper function to run top_k_per_row_decode test with given parameters.
"""
torch.set_default_device("cuda:0")
# Create test data
num_rows = batch_size * next_n
vocab_size = 20000
seq_lens = torch.randint(
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
)
......@@ -167,7 +186,9 @@ def test_top_k_per_row_decode(
row_indices = torch.arange(num_rows, device="cuda") // next_n
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42)
logits = create_random_logits(
row_starts, row_ends, torch.float32, 42, data_generation
)
# Create output tensors
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
......@@ -181,6 +202,7 @@ def test_top_k_per_row_decode(
num_rows,
logits.stride(0),
logits.stride(1),
top_k,
)
torch.cuda.synchronize()
......@@ -195,4 +217,41 @@ def test_top_k_per_row_decode(
# Compare results
assert compare_top_k_results(
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"
), "CUDA top_k_per_row_decode results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
top_k: int,
batch_size: int,
next_n: int,
data_generation: str,
) -> None:
"""
Test top_k_per_row with seq_lens tensor.
"""
vocab_size = 20000
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size() -> None:
"""
Test top_k_per_row_decode with large vocabulary size.
"""
top_k = 2048
batch_size = 2
next_n = 2
vocab_size = 300000
data_generation = "random"
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, data_generation
)
......@@ -684,11 +684,10 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ke,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row(
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
......@@ -696,6 +695,7 @@ def sparse_attn_indexer(
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
......@@ -738,7 +738,6 @@ def sparse_attn_indexer(
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
......@@ -749,6 +748,7 @@ def sparse_attn_indexer(
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment