Commit fc7980db authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.1' into v0.15.1-ori

parents 3eab7fef 1892993b
......@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def make_kv_cache_config_hybrid_model(
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
block_size: int,
num_blocks: int,
sliding_window_blocks: int,
second_spec_type: str = "sliding_window",
) -> KVCacheConfig:
if second_spec_type == "sliding_window":
second_spec = SlidingWindowSpec(
......@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
sliding_window=sliding_window_blocks * block_size,
)
elif second_spec_type == "mamba":
second_spec = MambaSpec(
......@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
def test_prefill_hybrid_model():
block_size = 16
manager = KVCacheManager(
make_kv_cache_config_hybrid_model(block_size, 21),
make_kv_cache_config_hybrid_model(block_size, 21, 2),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
......@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
hash_fn = sha256
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)]
num_full_blocks = 3
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
......@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
all_token_ids = common_token_ids + unique_token_ids
req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == 3
......@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
manager.free(req0)
manager.free(req1)
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block._cache
)
def test_partial_request_hit(
request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int,
):
req = make_request(
request_id, common_token_ids + unique_token_ids, block_size, sha256
)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(req.block_hashes) == 3
assert num_computed_tokens == expect_hit_length * block_size
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
cached_block_hash_to_block_bak[hash_with_group_id]
)
manager.free(req)
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit(
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"2",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
......@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit(
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"3",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[0], 0)],
0,
)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit(
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"4",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[2], 1),
......@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit(
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"5",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[2], 0)],
2,
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit(
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"6",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[2], 1)],
2,
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit(
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"7",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[2], 2)],
2,
)
# Evict different set of blocks for full attention and sliding window makes
......@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers
# have different hit length.
test_partial_request_hit(
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"8",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
......@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
)
def test_prefill_hybrid_model_eagle():
block_size = 16
kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3)
manager = KVCacheManager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
use_eagle=True,
)
hash_fn = sha256
# Complete 6 blocks (96 tokens)
num_full_blocks = 6
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [6] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(req0.block_hashes) == len(all_token_ids) // block_size
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0, len(all_token_ids), num_computed_tokens, computed_blocks
)
block_ids = (
[1, 2, 3, 4, 5, 6, 7],
[8, 9, 10, 11, 12, 13, 14],
[15, 16, 17, 18, 19, 20, 21],
)
assert blocks is not None and blocks.get_block_ids() == block_ids
# Check full block metadata
parent_block_hash = None
for i, full_block_ids in enumerate(zip(*(row[:-1] for row in block_ids))):
block_tokens = tuple(all_token_ids[i * block_size : (i + 1) * block_size])
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
for group_id, block_id in enumerate(full_block_ids):
blk_hash = manager.block_pool.blocks[block_id].block_hash
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == group_id
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash
# Check partial block metadata
for partial_block_id in (row[-1] for row in block_ids):
assert manager.block_pool.blocks[partial_block_id].block_hash is None
assert manager.block_pool.blocks[partial_block_id].ref_cnt == 1
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids = [6] * 5
all_token_ids = common_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == num_full_blocks
assert computed_blocks.get_block_ids() == (
[1, 2, 3, 4],
[0, 9, 10, 11],
[0, 16, 17, 18],
)
assert num_computed_tokens == 4 * block_size
num_new_tokens = len(all_token_ids) - num_computed_tokens
blocks = manager.allocate_slots(
req1, num_new_tokens, num_computed_tokens, computed_blocks
)
assert blocks is not None and blocks.get_block_ids() == (
[22, 23, 24],
[25, 26, 27],
[28, 29, 30],
)
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
assert block.ref_cnt == 2
block_hashes = req1.block_hashes
manager.free(req0)
manager.free(req1)
# Evict the blocks outside sliding window, does not affect the hit length.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"2",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
],
4,
)
# Evict the first block of full attention, makes total cache miss.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"3",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[0], 0)],
0,
)
# Evict the last block of all layers, reduces the hit length to 3.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"4",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[-1], 0),
make_block_hash_with_group_id(block_hashes[-1], 1),
make_block_hash_with_group_id(block_hashes[-1], 2),
],
3,
)
# Evict the last block of full attention, reduces the hit length to 3.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"5",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[-1], 0)],
3,
)
# Since the last block of full attention is dropped for eagle, evict
# the second last block of sliding window, reduces the hit length to 3.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"6",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[-2], 1)],
3,
)
# Since the last block of full attention is dropped for eagle, evict
# the second last block of sliding window, reduces the hit length to 3.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"7",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[-2], 2)],
3,
)
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
# The cache hit length of full attention is 4 * block_size.
# The cache hit length of sliding window is 3 * block_size.
# Then it is cache miss as the two type of layers
# have different hit length.
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"8",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[-1], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
],
0,
)
def _test_partial_request_hit(
manager: KVCacheManager,
block_size: int,
num_full_blocks,
request_id: str,
prompt_token_ids: list[int],
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int,
):
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block._cache
)
req = make_request(request_id, prompt_token_ids, block_size, sha256)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(req.block_hashes) == num_full_blocks
assert num_computed_tokens == expect_hit_length * block_size
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
cached_block_hash_to_block_bak[hash_with_group_id]
)
manager.free(req)
def _make_hybrid_kv_cache_config(
block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
......@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
manager.free(req1)
# Test cases with eagle enabled: Only test a single simple case for now.
# - 2 groups: 1 full + 1 other
_EAGLE_HYBRID_MODEL_TEST_CASES = [
# 2 groups: 1 full + 1 other
pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
]
@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations_eagle(
spec_types: list[str], expect_hit_length: int
):
"""
Test prefix caching with hybrid models (1 full attn + 1 other) with EAGLE.
More complex hybrid models with EAGLE are not yet supported (see issue #32802).
"""
block_size = 16
num_groups = len(spec_types)
# Allocate enough blocks for all groups
num_blocks = 10 * num_groups
kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
manager = KVCacheManager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
use_eagle=True,
)
hash_fn = sha256
# Complete 3 blocks (48 tokens)
num_full_blocks = 4
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
unique_token_ids = [4] * 7
all_token_ids = common_token_ids + unique_token_ids
# First request: no cache hit initially
req0 = make_request("0", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(req0.block_hashes) == num_full_blocks
assert not computed_blocks.blocks[0] # No cache hit initially
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0, len(all_token_ids), num_computed_tokens, computed_blocks
)
assert blocks is not None
# Should have blocks for all groups
assert len(blocks.get_block_ids()) == num_groups
# Second request: should hit cached blocks for common prefix
all_token_ids = common_token_ids + [6] * 5
req1 = make_request("1", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should hit cached blocks for all groups
assert num_computed_tokens == expect_hit_length * block_size
assert len(computed_blocks.blocks) == num_groups
# Verify each group has the correct number of computed blocks
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == expect_hit_length
# Allocate and verify blocks for second request
blocks = manager.allocate_slots(
req1,
len(all_token_ids) - num_computed_tokens,
num_computed_tokens,
computed_blocks,
)
assert blocks is not None
assert len(blocks.get_block_ids()) == num_groups
manager.free(req0)
manager.free(req1)
def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests.
......
......@@ -870,6 +870,66 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert stats.num_accepted_tokens_per_pos == expected[3]
def test_spec_decoding_stats_empty_output():
"""Test that spec decoding stats handle empty output tokens gracefully.
This is a regression test for a bug where empty sampled_token_ids
would cause num_accepted = len([]) - 1 = -1, leading to a
ValueError when incrementing a Prometheus counter with a negative value.
"""
num_spec_tokens = 3
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=1, num_tokens=1)
request = requests[0]
req_id = request.request_id
scheduler.add_request(request)
# Initial schedule (prefill)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
# Complete the prefill with a sampled token
model_runner_output = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
# Add draft tokens for speculation
draft_token_ids = DraftTokenIds([req_id], [[1, 2, 3]])
scheduler.update_draft_token_ids(draft_token_ids)
# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert req_id in output.scheduled_spec_decode_tokens
assert len(output.scheduled_spec_decode_tokens[req_id]) == 3
# Simulate empty output tokens (e.g., due to request abortion or error)
# This would previously cause num_accepted = -1 and crash
model_runner_output = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[]], # Empty output tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# This should not raise an error
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
# Spec decoding stats should be None since no tokens were generated
scheduler_stats = (
engine_core_outputs[0].scheduler_stats if engine_core_outputs else None
)
assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,
......
......@@ -19,7 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
......
......@@ -900,6 +900,8 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
if cuda_device_capability < 90 or cuda_device_capability >= 110:
return False
try:
return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)
except AttributeError:
......@@ -2032,35 +2034,20 @@ def selective_scan_fwd(
)
# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# are unable to properly handle non-contiguous
# tensors. It might be a good TODO(rasmith) to augment these kernels
# to be able to handle non-contiguous kernels for better performance.
def rocm_enforce_contiguous_skinny_gemm_inputs(
a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
a = a.contiguous() # no-op if already contiguous, else clone
b = b.contiguous() # no-op if already contiguous, else clone
return a, b
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(
a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
def wvSplitKrc(
a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count)
......@@ -2073,7 +2060,6 @@ def wvSplitKQ(
cu_count: int,
bias: torch.Tensor = None,
) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
return out
......
......@@ -361,6 +361,13 @@ def split_graph(
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id
split_op_graphs.append(subgraph_id)
# keep consecutive splitting ops together
# (we know node.next exists because node isn't the last (output) node)
if should_split(node.next, splitting_ops):
# this will get incremented by the next node
subgraph_id -= 1
else:
subgraph_id += 1
else:
node_to_subgraph_id[node] = subgraph_id
......
......@@ -582,6 +582,24 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
fast_moe_cold_start = True
"""Optimization for fast MOE cold start.
This is a bit of a hack that assumes that:
1. the only decoder forward pass being run is the current model
2. the decoder forward pass runs all of the MOEs in the order in which they
are initialized
When the above two conditions hold, this option greatly decreases cold start
time for MOE models.
If the above two conditions don't hold, then this option will lead to silent
incorrectness. The only condition in which this doesn't hold is speculative
decoding, where there is a draft model that may have MOEs in them.
NB: We're working on a longer-term solution that doesn't need these assumptions.
"""
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled"""
......@@ -597,6 +615,10 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
static_all_moe_layers: list[str] = field(default_factory=list, init=False)
"""The names of all the MOE layers in the model
"""
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [
......@@ -926,6 +948,15 @@ class CompilationConfig:
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
# unified_kv_cache_update has a string param that prevents Inductor
# from reusing piecewise graphs. Remove it from the compiled graph.
# This has the side-effect of excluding cache from cudagraphs but
# that doesn't seem to affect performance.
# https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition:
self.splitting_ops.append("vllm::unified_kv_cache_update")
elif len(self.splitting_ops) == 0:
if (
self.cudagraph_mode == CUDAGraphMode.PIECEWISE
......
......@@ -41,6 +41,7 @@ MTPModelTypes = Literal[
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
......@@ -264,6 +265,11 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
if hf_config.model_type == "step3p5":
hf_config.model_type = "step3p5_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
......
......@@ -217,9 +217,11 @@ class ForwardContext:
# the graph.
#
# The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext.
# custom ops needs in the ForwardContext (all_moe_layers)
# as well as a counter (moe_layer_index).
# The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list.
# When the custom op needs a layer string, get the next string
# from all_moe_layers and increment the counter.
#
# This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these
......@@ -233,7 +235,8 @@ class ForwardContext:
#
# If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will pop a string from this list.
remaining_moe_layers: list[str] | None = None
all_moe_layers: list[str] | None = None
moe_layer_index: int = 0
additional_kwargs: dict[str, Any] = field(default_factory=dict)
......@@ -271,17 +274,22 @@ def create_forward_context(
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
no_compile_layers = vllm_config.compilation_config.static_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
remaining_moe_layers = [
name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
]
remaining_moe_layers.reverse()
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
else:
logger.warning_once(
"vllm_config.compilation_config.fast_moe_cold_start is not "
"compatible with speculative decoding so we are ignoring "
"fast_moe_cold_start."
)
all_moe_layers = None
else:
all_moe_layers = None
return ForwardContext(
no_compile_layers=no_compile_layers,
remaining_moe_layers=remaining_moe_layers,
no_compile_layers=vllm_config.compilation_config.static_forward_context,
all_moe_layers=all_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},
......
......@@ -17,11 +17,63 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict
logger = init_logger(__name__)
@triton.jit
def _swiglustep_and_mul_kernel(
o_ptr,
o_stride,
x_ptr,
x_stride,
limit: tl.constexpr,
d: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
) -> None:
i = tl.program_id(axis=0).to(tl.int64)
j = tl.program_id(axis=1)
o_row_ptr = o_ptr + o_stride * i
x_row_ptr = x_ptr + x_stride * i
offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < d
gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)
gate_silu = tl.sigmoid(gate) * gate
gate_clamped = tl.minimum(gate_silu, limit)
up_clamped = tl.minimum(tl.maximum(up, -limit), limit)
result = gate_clamped * up_clamped
result = result.to(x_ptr.dtype.element_ty)
tl.store(o_row_ptr + offsets, result, mask=mask)
def swiglustep_and_mul_triton(
output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
b, n = input.shape
assert input.ndim == 2
assert n % 2 == 0
d = n // 2
def grid(meta):
return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
_swiglustep_and_mul_kernel[grid](
output,
output.stride(0),
input,
input.stride(0),
limit=limit,
d=d,
BLOCK_SIZE=1024,
)
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
......@@ -304,6 +356,44 @@ class SwigluOAIAndMul(CustomOp):
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
"""An activation function for SwiGLU with clamping.
Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, limit: float = 7.0):
super().__init__()
if limit is None:
raise ValueError("SwigluStepAndMul requires limit to be set.")
self.limit = limit
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x.chunk(2, dim=-1)
gate = F.silu(gate)
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
return gate * up
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
swiglustep_and_mul_triton(out, x, self.limit)
return out
def extra_repr(self) -> str:
return f"limit={repr(self.limit)}"
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
......
......@@ -657,7 +657,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_current_device() -> bool:
return current_platform.has_device_capability((10, 0))
p = current_platform
return p.is_cuda() and (
p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
......
......@@ -144,7 +144,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
return activation in ["silu", "swiglustep"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......
......@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_device_capability_family(100)
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
......
......@@ -84,11 +84,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return (
current_platform.is_cuda()
p.is_cuda()
and (
current_platform.is_device_capability((9, 0))
or current_platform.is_device_capability_family(100)
p.is_device_capability(90)
or p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
and has_flashinfer_cutlass_fused_moe()
)
......@@ -102,29 +105,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+
p = current_platform
scheme = (weight_key, activation_key)
# The following are supported by FlashInferExperts:
return (
# unquantized and fp8 static per-tensor on 9.0+
(
scheme
in [
(None, None),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
and p.has_device_capability(90)
)
# fp8 block-scale on 9.0
or (
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
and (p.is_device_capability((9, 0)))
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
and p.is_device_capability(90)
)
# nvfp4 on 10.0+
or (
(scheme == (kNvfp4Static, kNvfp4Dynamic))
and (p.is_device_capability_family(100))
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
)
)
......
......@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
......@@ -70,9 +69,14 @@ def _supports_routing_method(
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(rob): kernel requires Llama4.
return routing_method == RoutingMethodType.Llama4
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.Llama4,
# NOTE(mgoin): Disabled to investigate accuracy issues.
# See https://github.com/vllm-project/vllm/issues/33532
# RoutingMethodType.Renormalize,
# RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
......@@ -82,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm(
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def is_supported_config_trtllm_fp8(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
......@@ -111,13 +131,17 @@ def is_supported_config_trtllm(
return False, _make_reason("routing method")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason("activation format")
elif not _supports_router_logits_dtype(
moe_config.router_logits_dtype, moe_config.routing_method
):
return False, _make_reason("float32 router_logits with non-DeepSeekV3 routing")
return True, None
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
......@@ -131,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8(
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int = int(RoutingMethodType.DeepSeekV3),
routing_method_type: int,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
......@@ -144,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8(
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
if routing_bias is not None:
routing_bias = routing_bias.to(x.dtype)
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
......@@ -171,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8(
def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
......
......@@ -933,6 +933,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
SUPPORTED_W_A_FP8 = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
......
......@@ -45,6 +45,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
......@@ -1956,12 +1957,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
return activation in ["silu", "gelu", "swigluoai", "swiglustep"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
......
......@@ -408,6 +408,7 @@ class FusedMoE(CustomOp):
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
compilation_config.static_all_moe_layers.append(prefix)
self.layer_name = prefix
self.enable_eplb = enable_eplb
......@@ -1567,7 +1568,7 @@ class FusedMoE(CustomOp):
# Can be unavailable or None in unittests
if (
is_forward_context_available()
and get_forward_context().remaining_moe_layers is not None
and get_forward_context().all_moe_layers is not None
):
return "from_forward_context"
return self.layer_name
......@@ -1988,13 +1989,17 @@ class FusedMoE(CustomOp):
def get_layer_from_name(layer_name: str) -> FusedMoE:
forward_context: ForwardContext = get_forward_context()
if layer_name == "from_forward_context":
if not forward_context.remaining_moe_layers:
all_moe_layers = forward_context.all_moe_layers
assert all_moe_layers is not None
moe_layer_index = forward_context.moe_layer_index
if moe_layer_index >= len(all_moe_layers):
raise AssertionError(
"We expected the number of MOE layers in `remaining_moe_layers` "
"We expected the number of MOE layers in `all_moe_layers` "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
layer_name = forward_context.remaining_moe_layers.pop()
layer_name = all_moe_layers[moe_layer_index]
forward_context.moe_layer_index += 1
self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
return self
......
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm,
is_supported_config_trtllm_fp8,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
......@@ -213,7 +213,7 @@ def select_fp8_moe_backend(
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm(
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
......@@ -240,7 +240,7 @@ def select_fp8_moe_backend(
]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm(
supported, reason = is_supported_config_trtllm_fp8(
config,
weight_key,
activation_key,
......@@ -309,7 +309,7 @@ def select_fp8_moe_backend(
for backend in AVAILABLE_BACKENDS:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm(
supported, reason = is_supported_config_trtllm_fp8(
config,
weight_key,
activation_key,
......
......@@ -358,6 +358,11 @@ def apply_moe_activation(
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
......
......@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
RoutingMethodType,
int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
......@@ -1022,17 +1021,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale,
......@@ -1046,7 +1037,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment