Unverified Commit a6d0299c authored by Yanan Cao's avatar Yanan Cao Committed by GitHub
Browse files

[Kernel] [Helion] [6/N] Add num_tokens dimension to silu_mul autotuning and dispatching (#34185)


Signed-off-by: default avatarYanan Cao <gmagogsfm@gmail.com>
parent 6ce80f70
...@@ -54,8 +54,8 @@ def reset_config_manager_singleton(): ...@@ -54,8 +54,8 @@ def reset_config_manager_singleton():
class TestSiluMulFp8ConfigPicker: class TestSiluMulFp8ConfigPicker:
def test_config_picker_exact_match(self): def test_config_picker_exact_match(self):
config_keys = [ config_keys = [
"intermediate_2048_batchsize_256", "intermediate_2048_numtokens_256",
"intermediate_4096_batchsize_256", "intermediate_4096_numtokens_256",
] ]
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda") input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
...@@ -63,12 +63,12 @@ class TestSiluMulFp8ConfigPicker: ...@@ -63,12 +63,12 @@ class TestSiluMulFp8ConfigPicker:
args = (input_tensor, scale) args = (input_tensor, scale)
selected_key = pick_silu_mul_fp8_config(args, config_keys) selected_key = pick_silu_mul_fp8_config(args, config_keys)
assert selected_key == "intermediate_2048_batchsize_256" assert selected_key == "intermediate_2048_numtokens_256"
def test_config_picker_closest_match(self): def test_config_picker_closest_match(self):
config_keys = [ config_keys = [
"intermediate_2048_batchsize_256", "intermediate_2048_numtokens_256",
"intermediate_4096_batchsize_256", "intermediate_4096_numtokens_256",
] ]
# Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048 # Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda") input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda")
...@@ -76,10 +76,10 @@ class TestSiluMulFp8ConfigPicker: ...@@ -76,10 +76,10 @@ class TestSiluMulFp8ConfigPicker:
args = (input_tensor, scale) args = (input_tensor, scale)
selected_key = pick_silu_mul_fp8_config(args, config_keys) selected_key = pick_silu_mul_fp8_config(args, config_keys)
assert selected_key == "intermediate_4096_batchsize_256" assert selected_key == "intermediate_4096_numtokens_256"
def test_config_picker_fallback_to_default(self): def test_config_picker_fallback_to_default(self):
config_keys = ["default", "some_other_key"] config_keys = ["default"]
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda") input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
...@@ -101,9 +101,9 @@ class TestSiluMulFp8ConfigPicker: ...@@ -101,9 +101,9 @@ class TestSiluMulFp8ConfigPicker:
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120]) @pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
def test_config_picker_different_sizes(self, intermediate_size): def test_config_picker_different_sizes(self, intermediate_size):
config_keys = [ config_keys = [
"intermediate_2048_batchsize_256", "intermediate_2048_numtokens_256",
"intermediate_4096_batchsize_256", "intermediate_4096_numtokens_256",
"intermediate_5120_batchsize_256", "intermediate_5120_numtokens_256",
] ]
input_tensor = torch.randn( input_tensor = torch.randn(
...@@ -113,9 +113,73 @@ class TestSiluMulFp8ConfigPicker: ...@@ -113,9 +113,73 @@ class TestSiluMulFp8ConfigPicker:
args = (input_tensor, scale) args = (input_tensor, scale)
selected_key = pick_silu_mul_fp8_config(args, config_keys) selected_key = pick_silu_mul_fp8_config(args, config_keys)
expected_key = f"intermediate_{intermediate_size}_batchsize_256" expected_key = f"intermediate_{intermediate_size}_numtokens_256"
assert selected_key == expected_key assert selected_key == expected_key
def test_config_picker_numtokens_ceiling(self):
"""Pick the smallest numtokens >= input num_tokens."""
config_keys = [
"intermediate_4096_numtokens_8",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
"intermediate_4096_numtokens_256",
]
# 20 tokens -> should pick numtokens_32 (smallest >= 20)
input_tensor = torch.randn(20, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_32"
def test_config_picker_numtokens_exact(self):
"""Exact num_tokens match is preferred over ceiling."""
config_keys = [
"intermediate_4096_numtokens_8",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
]
input_tensor = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_32"
def test_config_picker_numtokens_fallback_to_largest(self):
"""Fall back to the largest numtokens when input exceeds all."""
config_keys = [
"intermediate_4096_numtokens_8",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
]
# 512 tokens -> exceeds all available, should pick largest (128)
input_tensor = torch.randn(512, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_128"
def test_config_picker_malformed_key_raises(self):
"""Malformed config keys should raise ValueError."""
config_keys = ["intermediate_4096_badformat_256"]
input_tensor = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
with pytest.raises(ValueError, match="Malformed config key"):
pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
def test_config_picker_default_ignored_when_valid_keys_exist(self):
"""'default' is skipped in favor of a real match."""
config_keys = [
"default",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
]
input_tensor = torch.randn(64, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_128"
class TestSiluMulFp8Correctness: class TestSiluMulFp8Correctness:
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from typing import Any from typing import Any
import regex as re
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -53,44 +54,78 @@ def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: ...@@ -53,44 +54,78 @@ def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return out.view(output_shape) return out.view(output_shape)
@silu_mul_fp8.register_input_generator # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
intermediate_sizes = [2048, 2880, 4096, 8192, 11008, 14336]
# Use the same num_tokens values as vLLM's default cudagraph capture sizes.
# See vllm/config/vllm.py _set_cudagraph_sizes() for the canonical formula.
num_tokens_list = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, 513, 16))
inputs = {}
for num_tokens in num_tokens_list:
for intermediate_size in intermediate_sizes:
# Input tensor has shape (num_tokens, 2 * intermediate_size)
# because silu_mul splits it into two halves
input_tensor = torch.randn(
num_tokens,
2 * intermediate_size,
device="cuda",
dtype=torch.bfloat16,
)
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
config_key = f"intermediate_{intermediate_size}_numtokens_{num_tokens}"
inputs[config_key] = (input_tensor, scale)
return inputs
@silu_mul_fp8.register_config_picker # type: ignore[misc] @silu_mul_fp8.register_config_picker # type: ignore[misc]
def pick_silu_mul_fp8_config( def pick_silu_mul_fp8_config(
args: tuple[Any, ...], config_keys: list[str] args: tuple[Any, ...], config_keys: list[str]
) -> str | None: ) -> str | None:
"""Pick the best pre-tuned config for the given input shape.
Selection strategy:
1. Find the closest intermediate_size among available configs
(exact match preferred).
2. Among the num_tokens values tuned for that intermediate_size, pick
the smallest num_tokens >= the input's num_tokens. If the input is
larger than all available num_tokens, fall back to the largest.
Config keys must be "default" or follow the format
"intermediate_{int}_numtokens_{int}".
"""
if not config_keys: if not config_keys:
return None return None
input_tensor, scale = args input_tensor, _scale = args
intermediate_size = input_tensor.shape[-1] // 2 intermediate_size = input_tensor.shape[-1] // 2
num_tokens = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]
# TODO(gmagosfm): Rerun autotuning to capture config for configs: dict[int, list[int]] = {}
# other batch sizes.
target_key = f"intermediate_{intermediate_size}_batchsize_256"
if target_key in config_keys:
return target_key
intermediate_sizes = []
for key in config_keys: for key in config_keys:
if key.startswith("intermediate_") and "_batchsize_256" in key: if key == "default":
try:
size_str = key.split("_")[1]
size = int(size_str)
intermediate_sizes.append((abs(size - intermediate_size), key))
except (ValueError, IndexError):
continue continue
match = re.fullmatch(r"intermediate_(\d+)_numtokens_(\d+)", key)
if not match:
raise ValueError(
f"Malformed config key '{key}', "
f"expected format 'intermediate_{{int}}_numtokens_{{int}}'"
)
isize_str, ntokens_str = match.groups()
configs.setdefault(int(isize_str), []).append(int(ntokens_str))
if intermediate_sizes: if not configs:
_, best_key = min(intermediate_sizes) return "default" if "default" in config_keys else None
logger.debug(
"No exact config for intermediate_size=%d, using closest match: %s", best_isize = min(configs, key=lambda s: abs(s - intermediate_size))
intermediate_size, available_ntokens = sorted(configs[best_isize])
best_key, best_ntokens = next(
(n for n in available_ntokens if n >= num_tokens), available_ntokens[-1]
) )
return best_key
if "default" in config_keys:
return "default"
return None return f"intermediate_{best_isize}_numtokens_{best_ntokens}"
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment