Unverified Commit f5972a87 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)


Signed-off-by: default avatarShahar Mor <smor@nvidia.com>
Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarShahar Mor <smor@nvidia.com>
Co-authored-by: default avatarRoi Koren <roik@nvidia.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent a9e15e04
...@@ -1200,6 +1200,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1200,6 +1200,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
}, },
is_available_online=False, is_available_online=False,
), ),
"NemotronHMTPModel": _HfExamplesInfo(
"nvidia/Nemotron-Super-Placeholder",
speculative_model="nvidia/Nemotron-Super-Placeholder",
is_available_online=False,
),
} }
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {
......
...@@ -41,6 +41,9 @@ def _make_vllm_config(block_size, max_model_len, max_num_seqs): ...@@ -41,6 +41,9 @@ def _make_vllm_config(block_size, max_model_len, max_num_seqs):
cudagraph_mode=CUDAGraphMode.FULL, cudagraph_mode=CUDAGraphMode.FULL,
max_cudagraph_capture_size=None, max_cudagraph_capture_size=None,
), ),
speculative_config=None,
num_speculative_tokens=0,
parallel_config=SimpleNamespace(decode_context_parallel_size=1),
scheduler_config=SimpleNamespace(max_num_seqs=max_num_seqs), scheduler_config=SimpleNamespace(max_num_seqs=max_num_seqs),
model_config=SimpleNamespace(max_model_len=max_model_len), model_config=SimpleNamespace(max_model_len=max_model_len),
) )
...@@ -92,7 +95,10 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): ...@@ -92,7 +95,10 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers():
has_initial_states_p=None, has_initial_states_p=None,
query_start_loc_p=None, query_start_loc_p=None,
num_computed_tokens_p=None, num_computed_tokens_p=None,
state_indices_tensor=builder_a.state_indices_tensor[:num_reqs], state_indices_tensor_p=None,
query_start_loc_d=None,
num_accepted_tokens=None,
state_indices_tensor_d=builder_a.state_indices_tensor_d[:num_reqs],
block_idx_last_scheduled_token=( block_idx_last_scheduled_token=(
builder_a.block_idx_last_scheduled_token[:num_reqs] builder_a.block_idx_last_scheduled_token[:num_reqs]
), ),
......
...@@ -36,6 +36,7 @@ MTPModelTypes = Literal[ ...@@ -36,6 +36,7 @@ MTPModelTypes = Literal[
"glm4_moe_lite_mtp", "glm4_moe_lite_mtp",
"glm_ocr_mtp", "glm_ocr_mtp",
"ernie_mtp", "ernie_mtp",
"nemotron_h_mtp",
"exaone_moe_mtp", "exaone_moe_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
"qwen3_5_mtp", "qwen3_5_mtp",
...@@ -255,6 +256,19 @@ class SpeculativeConfig: ...@@ -255,6 +256,19 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]} {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
) )
if (
hf_config.model_type == "nemotron_h"
and hasattr(hf_config, "num_nextn_predict_layers")
and hf_config.num_nextn_predict_layers > 0
):
# Check if this is an MTP variant
hf_config.model_type = "nemotron_h_mtp"
if hf_config.model_type == "nemotron_h_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update(
{"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
)
if hf_config.model_type == "qwen3_next": if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp" hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp": if hf_config.model_type == "qwen3_next_mtp":
...@@ -325,7 +339,7 @@ class SpeculativeConfig: ...@@ -325,7 +339,7 @@ class SpeculativeConfig:
if self.target_model_config is None: if self.target_model_config is None:
raise ValueError("target_model_config must be present for mtp") raise ValueError("target_model_config must be present for mtp")
if self.target_model_config.hf_text_config.model_type == "deepseek_v32": if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
# FIXME(luccafong): cudgraph with v32 MTP is not supported, # FIXME(luccafong): cudagraph with v32 MTP is not supported,
# remove this when the issue is fixed. # remove this when the issue is fixed.
self.enforce_eager = True self.enforce_eager = True
# use the draft model from the same model: # use the draft model from the same model:
...@@ -427,7 +441,7 @@ class SpeculativeConfig: ...@@ -427,7 +441,7 @@ class SpeculativeConfig:
self.method = "mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: if self.num_speculative_tokens > 1:
logger.warning( logger.warning(
"Enabling num_speculative_tokens > 1 will run" "Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer" "multiple times of forward on same MTP layer"
",which may result in lower acceptance rate" ",which may result in lower acceptance rate"
) )
...@@ -712,6 +726,7 @@ class SpeculativeConfig: ...@@ -712,6 +726,7 @@ class SpeculativeConfig:
"hunyuan_vl", "hunyuan_vl",
"hunyuan_v1_dense", "hunyuan_v1_dense",
"afmoe", "afmoe",
"nemotron_h",
] ]
if ( if (
self.method == "eagle3" self.method == "eagle3"
......
...@@ -395,6 +395,15 @@ class VllmConfig: ...@@ -395,6 +395,15 @@ class VllmConfig:
] ]
return hash_str return hash_str
@property
def num_speculative_tokens(self) -> int:
if (
self.speculative_config is not None
and self.speculative_config.num_speculative_tokens is not None
):
return self.speculative_config.num_speculative_tokens
return 0
@property @property
def needs_dp_coordinator(self) -> bool: def needs_dp_coordinator(self) -> bool:
""" """
......
...@@ -41,14 +41,6 @@ class MambaBase(AttentionLayerBase): ...@@ -41,14 +41,6 @@ class MambaBase(AttentionLayerBase):
pass pass
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
if (
vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next", "qwen3_5", "qwen3_5_moe"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = vllm_config.cache_config.mamba_block_size mamba_block_size = vllm_config.cache_config.mamba_block_size
page_size_padded = vllm_config.cache_config.mamba_page_size_padded page_size_padded = vllm_config.cache_config.mamba_page_size_padded
return MambaSpec( return MambaSpec(
......
...@@ -265,7 +265,8 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -265,7 +265,8 @@ class MambaMixer(MambaBase, PluggableLayer):
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba1AttentionMetadata) assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
...@@ -295,17 +296,13 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -295,17 +296,13 @@ class MambaMixer(MambaBase, PluggableLayer):
prefill_decode_split = split_batch_to_prefill_and_decode( prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC, hidden_states_BC,
gate, gate,
state_indices_tensor,
num_prefill_tokens, num_prefill_tokens,
num_prefills,
num_decode_tokens, num_decode_tokens,
) )
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
gate_p = prefill_decode_split.gate_p gate_p = prefill_decode_split.gate_p
gate_d = prefill_decode_split.gate_d gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if is_mamba_cache_all: if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
...@@ -477,16 +474,12 @@ class PrefillDecodeSplit(NamedTuple): ...@@ -477,16 +474,12 @@ class PrefillDecodeSplit(NamedTuple):
hidden_states_BC_d: torch.Tensor hidden_states_BC_d: torch.Tensor
gate_p: torch.Tensor gate_p: torch.Tensor
gate_d: torch.Tensor gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
def split_batch_to_prefill_and_decode( def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor, hidden_states_BC: torch.Tensor,
gate: torch.Tensor, gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
num_prefill_tokens: int, num_prefill_tokens: int,
num_prefills: int,
num_decode_tokens: int, num_decode_tokens: int,
) -> PrefillDecodeSplit: ) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_decode_tokens num_actual_tokens = num_prefill_tokens + num_decode_tokens
...@@ -501,20 +494,11 @@ def split_batch_to_prefill_and_decode( ...@@ -501,20 +494,11 @@ def split_batch_to_prefill_and_decode(
gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1 gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
) )
# num_decode_tokens accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[: num_decode_tokens + num_prefills],
[num_decode_tokens, num_prefills],
dim=0,
)
return PrefillDecodeSplit( return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p, hidden_states_BC_p=hidden_states_BC_p,
hidden_states_BC_d=hidden_states_BC_d, hidden_states_BC_d=hidden_states_BC_d,
gate_p=gate_p, gate_p=gate_p,
gate_d=gate_d, gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
) )
......
...@@ -477,7 +477,8 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -477,7 +477,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
dim=-1, dim=-1,
) )
compilation_config = get_current_vllm_config().compilation_config vllm_config = get_current_vllm_config()
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
...@@ -488,6 +489,8 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -488,6 +489,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
self.cache_config = cache_config self.cache_config = cache_config
self.prefix = prefix self.prefix = prefix
self.num_spec = vllm_config.num_speculative_tokens
# Pre-compute sizes for forward pass # Pre-compute sizes for forward pass
self.tped_intermediate_size = self.intermediate_size // self.tp_size self.tped_intermediate_size = self.intermediate_size // self.tp_size
self.tped_conv_size = self.conv_dim // self.tp_size self.tped_conv_size = self.conv_dim // self.tp_size
...@@ -576,7 +579,6 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -576,7 +579,6 @@ class MambaMixer2(MambaBase, PluggableLayer):
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size chunk_size = attn_metadata.chunk_size
...@@ -584,6 +586,12 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -584,6 +586,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p query_start_loc_p = attn_metadata.query_start_loc_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
num_accepted_tokens = attn_metadata.num_accepted_tokens
query_start_loc_d = attn_metadata.query_start_loc_d
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
if attn_metadata is None: if attn_metadata is None:
# profile run # profile run
...@@ -593,29 +601,21 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -593,29 +601,21 @@ class MambaMixer2(MambaBase, PluggableLayer):
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C) hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
return hidden_states return hidden_states
num_prefills = attn_metadata.num_prefills # request count num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decode_tokens # token count (=request) num_prefill_tokens = attn_metadata.num_prefill_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0 has_prefill = num_prefills > 0
has_decode = num_decodes > 0 has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes num_actual_tokens = num_prefill_tokens + num_decode_tokens
# Separate prefill and decode by splitting varlen input
# Split along token dimension # Split along token dimension
hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens], hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decode_tokens, num_prefill_tokens],
dim=0, dim=0,
) )
dt_d, dt_p = torch.split( dt_d, dt_p = torch.split(
dt[:num_actual_tokens], dt[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decode_tokens, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[:num_actual_tokens],
[num_decodes, num_prefills],
dim=0, dim=0,
) )
...@@ -642,16 +642,16 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -642,16 +642,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
) )
num_computed_tokens_p = attn_metadata.num_computed_tokens_p num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else: else:
block_idx_last_computed_token_d = None
block_idx_last_computed_token_p = None block_idx_last_computed_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_scheduled_token_p = None block_idx_last_scheduled_token_p = None
block_idx_first_scheduled_token_p = None block_idx_first_scheduled_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_computed_token_d = None
num_computed_tokens_p = None num_computed_tokens_p = None
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
output[:num_actual_tokens], output[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decode_tokens, num_prefill_tokens],
dim=0, dim=0,
) )
...@@ -709,6 +709,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -709,6 +709,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
) )
# NOTE: final output is an in-place update of out tensor # NOTE: final output is an in-place update of out tensor
assert preallocated_ssm_out_p is not None
varlen_states = mamba_chunk_scan_combined_varlen( varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view( hidden_states_p.view(
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
...@@ -840,6 +841,9 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -840,6 +841,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
conv_state_indices=state_indices_tensor_d, conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d, block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d, initial_state_idx=block_idx_last_computed_token_d,
num_accepted_tokens=num_accepted_tokens,
query_start_loc=query_start_loc_d,
max_query_len=state_indices_tensor_d.size(-1),
) )
hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn( hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
...@@ -862,6 +866,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -862,6 +866,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
-1, self.num_heads // self.tp_size, self.head_dim -1, self.num_heads // self.tp_size, self.head_dim
) )
assert preallocated_ssm_out_d is not None
# - the hidden is reshaped into (bs, num_heads, head_dim) # - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected # - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d # using state_indices_tensor_d
...@@ -879,7 +884,9 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -879,7 +884,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
dt_softplus=True, dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input, state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output, dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), out=preallocated_ssm_out_d.view(num_decode_tokens, -1, self.head_dim),
num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell, is_blackwell=self.is_blackwell,
) )
...@@ -901,6 +908,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -901,6 +908,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
head_dim=self.head_dim, head_dim=self.head_dim,
state_size=self.ssm_state_size, state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size, conv_kernel=self.conv_kernel_size,
num_spec=self.num_spec,
) )
@property @property
......
...@@ -133,6 +133,7 @@ class MambaStateShapeCalculator: ...@@ -133,6 +133,7 @@ class MambaStateShapeCalculator:
head_dim: int, head_dim: int,
state_size: int, state_size: int,
conv_kernel: int, conv_kernel: int,
num_spec: int = 0,
) -> tuple[tuple[int, int], tuple[int, int, int]]: ) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards # if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it # to ensure all groups needed by a head is sharded along with it
...@@ -141,7 +142,7 @@ class MambaStateShapeCalculator: ...@@ -141,7 +142,7 @@ class MambaStateShapeCalculator:
conv_dim = intermediate_size + 2 * n_groups * state_size conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis # contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp_world_size))
# These are not TP-ed as they depend on A, dt_bias, D # These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small # - they are typically small
......
...@@ -1155,7 +1155,9 @@ def causal_conv1d_update( ...@@ -1155,7 +1155,9 @@ def causal_conv1d_update(
if conv_state_indices is None: if conv_state_indices is None:
assert conv_state.size(0) >= batch assert conv_state.size(0) >= batch
else: else:
assert (batch,) == conv_state_indices.shape assert batch == conv_state_indices.shape[0], (
f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}"
)
assert num_cache_lines >= batch assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this assert weight.stride(1) == 1 # Need this
......
...@@ -119,7 +119,8 @@ class ShortConv(MambaBase, CustomOp): ...@@ -119,7 +119,8 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, ShortConvAttentionMetadata) assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p query_start_loc_p = attn_metadata.query_start_loc_p
...@@ -163,13 +164,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -163,13 +164,6 @@ class ShortConv(MambaBase, CustomOp):
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
dim=0, dim=0,
) )
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
conv_output_list = [] conv_output_list = []
if has_prefill: if has_prefill:
......
...@@ -228,6 +228,7 @@ class Mamba2ForCausalLM( ...@@ -228,6 +228,7 @@ class Mamba2ForCausalLM(
head_dim=hf_config.head_dim, head_dim=hf_config.head_dim,
state_size=hf_config.state_size, state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
num_spec=vllm_config.num_speculative_tokens,
) )
@classmethod @classmethod
......
...@@ -636,6 +636,9 @@ class NemotronHModel(nn.Module): ...@@ -636,6 +636,9 @@ class NemotronHModel(nn.Module):
hidden_states, _ = self.norm_f(hidden_states, residual) hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states return hidden_states
def is_spec_layer(self, config: NemotronHConfig, weight_name: str) -> bool:
return weight_name.startswith("mtp.")
def _get_max_n_routed_experts(self) -> int: def _get_max_n_routed_experts(self) -> int:
"""Get max n_routed_experts from config or block_configs for puzzle models. """Get max n_routed_experts from config or block_configs for puzzle models.
...@@ -702,6 +705,10 @@ class NemotronHModel(nn.Module): ...@@ -702,6 +705,10 @@ class NemotronHModel(nn.Module):
if name is None: if name is None:
continue continue
# Skip MTP/spec decode layers early (before stacked params mapping)
if name.startswith("mtp."):
continue
# load stacked params # load stacked params
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -845,6 +852,7 @@ class NemotronHForCausalLM( ...@@ -845,6 +852,7 @@ class NemotronHForCausalLM(
head_dim=hf_config.mamba_head_dim, head_dim=hf_config.mamba_head_dim,
state_size=hf_config.ssm_state_size, state_size=hf_config.ssm_state_size,
conv_kernel=hf_config.conv_kernel, conv_kernel=hf_config.conv_kernel,
num_spec=vllm_config.num_speculative_tokens,
) )
@classmethod @classmethod
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NemotronH-MTP model with attention layers."""
import typing
from collections.abc import Callable, Iterable
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
from .interfaces import SupportsPP
from .nemotron_h import (
NemotronHAttentionDecoderLayer,
NemotronHMoEDecoderLayer,
)
class NemotronHMTPAttentionDecoderLayer(NemotronHAttentionDecoderLayer):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
has_start_projections: bool = False,
has_end_norm: bool = False,
) -> None:
super().__init__(
config=config,
layer_idx=layer_idx,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
parallel_config=parallel_config,
prefix=prefix,
)
self.has_start_projections = has_start_projections
self.has_end_norm = has_end_norm
if has_start_projections:
self.enorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.hnorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Fusion layer to combine embeddings with target hidden states
self.eh_proj = ColumnParallelLinear(
input_size=config.hidden_size * 2,
output_size=config.hidden_size,
bias=False,
gather_output=True,
params_dtype=config.dtype
if hasattr(config, "dtype")
else torch.bfloat16,
quant_config=quant_config,
prefix=f"{prefix}.eh_proj",
)
if has_end_norm:
self.final_layernorm = RMSNorm(
config.hidden_size,
eps=getattr(config, "layer_norm_epsilon", 1e-5),
)
def forward(
self,
inputs_embeds: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Start projections (Fusion)
if self.has_start_projections:
# Normalize both inputs before fusion
assert inputs_embeds is not None
inputs_embeds_normed = self.enorm(inputs_embeds)
previous_hidden_states_normed = self.hnorm(hidden_states)
# Fuse via concatenation and linear projection
fused = torch.cat(
[inputs_embeds_normed, previous_hidden_states_normed], dim=-1
)
hidden_states, _ = self.eh_proj(fused)
# Call parent forward (Attention)
# Parent forward expects: hidden_states, residual
hidden_states, residual = super().forward(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
# End norm
if self.has_end_norm:
if residual is not None:
hidden_states = hidden_states + residual
residual = None # Consumed residual
hidden_states = self.final_layernorm(hidden_states)
return hidden_states, residual
class NemotronHMTPMoEDecoderLayer(NemotronHMoEDecoderLayer):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
prefix: str = "",
has_start_projections: bool = False,
has_end_norm: bool = False,
) -> None:
super().__init__(
config=config,
layer_idx=layer_idx,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
parallel_config=parallel_config,
prefix=prefix,
)
self.has_start_projections = has_start_projections
self.has_end_norm = has_end_norm
if has_start_projections:
self.enorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.hnorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Fusion layer to combine embeddings with target hidden states
self.eh_proj = ColumnParallelLinear(
input_size=config.hidden_size * 2,
output_size=config.hidden_size,
bias=False,
gather_output=True,
params_dtype=config.dtype
if hasattr(config, "dtype")
else torch.bfloat16,
quant_config=quant_config,
prefix=f"{prefix}.eh_proj",
)
if has_end_norm:
self.final_layernorm = RMSNorm(
config.hidden_size,
eps=getattr(config, "layer_norm_epsilon", 1e-5),
)
def forward(
self,
inputs_embeds: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Start projections (Fusion)
if self.has_start_projections:
# Normalize both inputs before fusion
assert inputs_embeds is not None
inputs_embeds_normed = self.enorm(inputs_embeds)
previous_hidden_states_normed = self.hnorm(hidden_states)
# Fuse via concatenation and linear projection
fused = torch.cat(
[inputs_embeds_normed, previous_hidden_states_normed], dim=-1
)
hidden_states, _ = self.eh_proj(fused)
# Call parent forward (MoE)
hidden_states, residual = super().forward(
hidden_states=hidden_states,
residual=residual,
)
# End norm
if self.has_end_norm:
if residual is not None:
hidden_states = hidden_states + residual
residual = None # Consumed residual
hidden_states = self.final_layernorm(hidden_states)
return hidden_states, residual
@support_torch_compile
class NemotronHMultiTokenPredictor(nn.Module):
"""MTP predictor with NemotronH layers."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
assert self.num_mtp_layers == 1, (
"Only one MTP layer is supported for NemotronH-MTP"
)
self.pattern_str = config.mtp_hybrid_override_pattern
self.pattern_len = len(self.pattern_str)
assert self.pattern_len > 0
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
)
# Build flat list of layers
self.layers = torch.nn.ModuleDict()
# Total number of physical layers = num_steps * pattern_len
total_layers = self.num_mtp_layers * self.pattern_len
for i in range(total_layers):
step_rel_idx = i % self.pattern_len
char = self.pattern_str[step_rel_idx]
is_start_of_step = step_rel_idx == 0
is_end_of_step = step_rel_idx == self.pattern_len - 1
layer_prefix = f"{prefix}.layers.{i}"
# TODO smor- remove double layers formation
common_kwargs = dict(
config=config,
layer_idx=self.mtp_start_layer_idx + i,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
parallel_config=vllm_config.parallel_config,
prefix=layer_prefix,
has_start_projections=is_start_of_step,
has_end_norm=is_end_of_step,
)
if char == "*":
self.layers[str(i)] = NemotronHMTPAttentionDecoderLayer(**common_kwargs)
elif char == "E":
self.layers[str(i)] = NemotronHMTPMoEDecoderLayer(**common_kwargs)
else:
raise NotImplementedError(
f"Pattern char '{char}' in {self.pattern_str} not implemented"
)
self.make_empty_intermediate_tensors: Callable[..., IntermediateTensors] = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
assert self.embed_tokens is not None, (
"embed_tokens not initialized - must be shared from target model"
)
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
residual = None
for i in range(self.pattern_len):
hidden_states, residual = self.layers[str(i)](
inputs_embeds=inputs_embeds,
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
return hidden_states
class NemotronHMTP(nn.Module, SupportsPP):
"""NemotronH MTP model."""
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.config = config
self.quant_config = vllm_config.quant_config
# Needed for load_weights mapping
self.mtp_start_layer_idx = config.num_hidden_layers
# EPLB config for experts
self.num_redundant_experts = 0
if vllm_config.parallel_config and vllm_config.parallel_config.eplb_config:
self.num_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
# MTP predictor
self.model = NemotronHMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
# LM head for generating logits
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(self.config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
"""Forward - applies attention-based MTP."""
hidden_states = self.model(
input_ids,
positions,
hidden_states,
intermediate_tensors,
inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
"""Compute logits for DRAFT token generation."""
assert self.lm_head is not None, (
"lm_head not initialized - must be shared from target model"
)
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load MTP weights with proper name remapping."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = []
if hasattr(self.config, "n_routed_experts") and self.config.n_routed_experts:
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="", # Empty - non-gated MoE
num_experts=self.config.n_routed_experts,
num_redundant_experts=self.num_redundant_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Only process MTP weights - skip all non-MTP weights
if (
not name.startswith("mtp.")
and "embeddings" not in name
and "lm_head" not in name
):
continue
# Skip rotary embeddings (computed, not loaded)
if "rotary_emb.inv_freq" in name:
continue
name = name.replace("mtp.layers.", "model.layers.")
if "embeddings" in name:
name = name.replace("embeddings", "embed_tokens")
if name.startswith("backbone."):
name = name.replace("backbone.", "model.")
# Handle stacked parameters (qkv_proj) for attention layers
is_stacked = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# Must be in a mixer (attention layer)
if ".mixer." not in name:
continue
is_stacked = True
stacked_name = name.replace(weight_name, param_name)
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
continue
if stacked_name not in params_dict:
# Might be that mapping failed or param doesn't exist
continue
param = params_dict[stacked_name]
weight_loader = getattr(param, "weight_loader", None)
if weight_loader is not None:
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(stacked_name)
break
if is_stacked:
continue
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
# weight_name is like "experts.0.up_proj."
if weight_name not in name:
continue
is_expert_weight = True
# Replace the expert-specific weight name with fused parameter name
# e.g., "experts.0.up_proj." -> "experts.w13_"
name_mapped = name.replace(weight_name, param_name)
if name_mapped not in params_dict:
continue
param = params_dict[name_mapped]
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
loaded_params.add(name_mapped)
break
if is_expert_weight:
continue
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -266,7 +266,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -266,7 +266,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size chunk_size = attn_metadata.chunk_size
...@@ -309,13 +310,6 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -309,13 +310,6 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
gate_d, gate_p = torch.split( gate_d, gate_p = torch.split(
gate[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0 gate[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0
) )
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
# Preallocate output tensor to avoid memcpy cost for merging prefill # Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs # and decode outputs
preallocated_ssm_out = torch.empty( preallocated_ssm_out = torch.empty(
...@@ -336,7 +330,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -336,7 +330,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
if has_prefill: if has_prefill:
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions # - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor" # pointed to by "state_indices_tensor_p"
x = hidden_states_p.transpose(0, 1) # this is the form that causal-conv see x = hidden_states_p.transpose(0, 1) # this is the form that causal-conv see
hidden_states_p = causal_conv1d_fn( hidden_states_p = causal_conv1d_fn(
x, x,
......
...@@ -522,6 +522,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -522,6 +522,7 @@ _SPECULATIVE_DECODING_MODELS = {
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"), "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
"NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"), "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
......
...@@ -51,6 +51,8 @@ class NemotronHConfig(PretrainedConfig): ...@@ -51,6 +51,8 @@ class NemotronHConfig(PretrainedConfig):
The pattern of the hybrid model. The pattern is a string of The pattern of the hybrid model. The pattern is a string of
characters where each character represents characters where each character represents
M: Mamba2, *: Attention, -: MLP M: Mamba2, *: Attention, -: MLP
mtp_hybrid_override_pattern (`str`, *optional*, defaults to `"*E"`):
The pattern of the MTP layers.
num_attention_heads (`int`, *optional*, defaults to 32): num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Number of attention heads for each attention layer in the
Transformer encoder. Transformer encoder.
...@@ -150,6 +152,7 @@ class NemotronHConfig(PretrainedConfig): ...@@ -150,6 +152,7 @@ class NemotronHConfig(PretrainedConfig):
intermediate_size=21504, intermediate_size=21504,
num_hidden_layers=52, num_hidden_layers=52,
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
mtp_hybrid_override_pattern="*E",
num_attention_heads=32, num_attention_heads=32,
head_dim=128, head_dim=128,
num_key_value_heads=8, # nemo: num_query_groups num_key_value_heads=8, # nemo: num_query_groups
...@@ -203,6 +206,7 @@ class NemotronHConfig(PretrainedConfig): ...@@ -203,6 +206,7 @@ class NemotronHConfig(PretrainedConfig):
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.hybrid_override_pattern = hybrid_override_pattern self.hybrid_override_pattern = hybrid_override_pattern
self.mtp_hybrid_override_pattern = mtp_hybrid_override_pattern
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.head_dim = head_dim self.head_dim = head_dim
self.sliding_window = sliding_window self.sliding_window = sliding_window
...@@ -215,10 +219,9 @@ class NemotronHConfig(PretrainedConfig): ...@@ -215,10 +219,9 @@ class NemotronHConfig(PretrainedConfig):
assert len(self.hybrid_override_pattern) == self.num_hidden_layers, ( assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
"hybrid_override_pattern must have same length as num_hidden_layers" "hybrid_override_pattern must have same length as num_hidden_layers"
) )
assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), ( assert re.match(r"^[*-ME]+$", self.hybrid_override_pattern), (
"hybrid_override_pattern must only contain characters 'M', '*', or '-'" "hybrid_override_pattern must only contain characters 'M', '*', '-', or 'E'"
) )
# for backward compatibility # for backward compatibility
if num_key_value_heads is None: if num_key_value_heads is None:
num_key_value_heads = num_attention_heads num_key_value_heads = num_attention_heads
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import Any
import torch import torch
...@@ -200,8 +201,11 @@ class Mamba2AttentionMetadataBuilder( ...@@ -200,8 +201,11 @@ class Mamba2AttentionMetadataBuilder(
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
**kwargs: Any,
) -> Mamba2AttentionMetadata: ) -> Mamba2AttentionMetadata:
common = self._compute_common_metadata(common_attn_metadata) common = self._compute_common_metadata(
common_attn_metadata, num_accepted_tokens=kwargs.get("num_accepted_tokens")
)
seq_idx_p = None seq_idx_p = None
cu_chunk_seqlen_p = None cu_chunk_seqlen_p = None
......
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc import abc
import copy from dataclasses import dataclass, replace
from dataclasses import dataclass from typing import Any, ClassVar, TypeVar
from typing import ClassVar, TypeVar
import torch import torch
...@@ -35,12 +34,21 @@ class BaseMambaAttentionMetadata: ...@@ -35,12 +34,21 @@ class BaseMambaAttentionMetadata:
num_reqs: int num_reqs: int
# The following tensors only contain prefill requests and will be None if # The following tensors only contain prefill requests and will be None if
# the batch has no prefill request. # the batch has no prefill requests.
has_initial_states_p: torch.Tensor | None has_initial_states_p: torch.Tensor | None
query_start_loc_p: torch.Tensor | None query_start_loc_p: torch.Tensor | None
num_computed_tokens_p: torch.Tensor | None num_computed_tokens_p: torch.Tensor | None
state_indices_tensor_p: torch.Tensor | None
state_indices_tensor: torch.Tensor # The following tensors are used for decode requests and
# speculative decoding compatibility, and will be None if the batch
# has no decode requests.
state_indices_tensor_d: torch.Tensor | None
query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,]
# Number of accepted tokens for each spec sequence (for loading correct checkpoint)
# Includes the bonus token (so minimum is 1)
num_accepted_tokens: torch.Tensor | None # shape: [batch,]
# The following tensors are only used for prefix caching in all mode and # The following tensors are only used for prefix caching in all mode and
# are None if disabled # are None if disabled
...@@ -60,9 +68,9 @@ class BaseMambaAttentionMetadata: ...@@ -60,9 +68,9 @@ class BaseMambaAttentionMetadata:
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
metadata_cls: type[M] metadata_cls: type[M]
reorder_batch_threshold: int = 1 reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = ( _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
) # Will be disabled if speculative decoding is used
supports_update_block_table: bool = True supports_update_block_table: bool = True
def __init__( def __init__(
...@@ -74,6 +82,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -74,6 +82,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
): ):
super().__init__(kv_cache_spec, layer_names, vllm_config, device) super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# Enable speculative decoding support
self.speculative_config = vllm_config.speculative_config
self.compilation_config = vllm_config.compilation_config
self.num_spec_tokens: int = vllm_config.num_speculative_tokens
self.use_spec_decode = self.num_spec_tokens > 0
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
...@@ -84,13 +98,17 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -84,13 +98,17 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
) )
if self.vllm_config.cache_config.mamba_cache_mode == "all": if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.state_indices_tensor = torch.empty( max_num_blocks = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size,
)
# Speculative decoding not supported with prefix caching,
# so keep shape consistent with prefill buffer
# TODO: reduce this size as needed for decode-only cudagraph capture
self.state_indices_tensor_d = torch.empty(
( (
self.decode_cudagraph_max_bs, self.decode_cudagraph_max_bs,
cdiv( max_num_blocks,
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size,
),
), ),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
...@@ -106,12 +124,25 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -106,12 +124,25 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
device=device, device=device,
) )
else: else:
self.state_indices_tensor = torch.empty( self.state_indices_tensor_d = torch.empty(
(self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens),
dtype=torch.int32,
device=device,
)
# For speculative decoding, we need to store the following buffers
# for CUDA graph capture during decode
if self.num_spec_tokens > 0:
self.decode_num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self._init_reorder_batch_threshold(1, self.use_spec_decode)
if self.use_spec_decode:
self.supports_update_block_table = False
def build_for_cudagraph_capture( def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata self, common_attn_metadata: CommonAttentionMetadata
) -> M: ) -> M:
...@@ -121,26 +152,38 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -121,26 +152,38 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
""" """
m = common_attn_metadata m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, ( assert (
m.max_query_len <= 1 + self.num_spec_tokens
and m.num_reqs <= self.decode_cudagraph_max_bs
), (
"Mamba only supports decode-only full CUDAGraph capture. " "Mamba only supports decode-only full CUDAGraph capture. "
"Make sure all cudagraph capture sizes <= max_num_seq." "Make sure all cudagraph capture sizes <= max_num_seq."
) )
m.max_query_len = 1 # decode-only assert m.max_query_len == 1 + self.num_spec_tokens # decode-only
return self.build(0, m) num_accepted_tokens = None
if self.num_spec_tokens > 0:
num_accepted_tokens = torch.diff(m.query_start_loc)
return self.build(0, m, num_accepted_tokens=num_accepted_tokens)
def build( def build(
self, self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
*,
num_accepted_tokens: torch.Tensor | None = None,
**kwargs: Any,
) -> M: ) -> M:
""" """
Default build implementation for Mamba-like attention backends. Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata. Subclasses (e.g., Mamba2) can override to add additional metadata.
""" """
return self._compute_common_metadata(common_attn_metadata) return self._compute_common_metadata(
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
)
def _compute_prefix_caching_block_indices( def _compute_prefix_caching_block_indices(
self, self,
...@@ -176,21 +219,32 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -176,21 +219,32 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
def _compute_common_metadata( def _compute_common_metadata(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
*,
num_accepted_tokens: torch.Tensor | None = None,
) -> M: ) -> M:
""" """
Compute metadata common to both Mamba1 and Mamba2. Compute metadata common to both Mamba1 and Mamba2.
""" """
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
# Treat multi-token queries as decode requests when
# speculative decoding is enabled. Otherwise, use the
# default decode threshold to prevent misclassification
# of prefill queries as decode requests.
decode_threshold = (
self.reorder_batch_threshold if num_accepted_tokens is not None else 1
)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold common_attn_metadata, decode_threshold=decode_threshold
) )
) )
# Need flags to indicate if there are initial states # Need flags to indicate if there are initial states
has_initial_states_p = None has_initial_states_p = None
query_start_loc_p = None query_start_loc_p = None
query_start_loc_d = None
num_computed_tokens = None num_computed_tokens = None
num_computed_tokens_p = None num_computed_tokens_p = None
...@@ -218,13 +272,31 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -218,13 +272,31 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata, mamba_block_size common_attn_metadata, mamba_block_size
) )
else: else:
# Always return just a single block per each request:
state_indices_tensor = mamba_get_block_table_tensor( state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor, common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens, common_attn_metadata.seq_lens,
self.kv_cache_spec, self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode, self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0] )
if state_indices_tensor.dim() == 1:
state_indices_tensor = state_indices_tensor.unsqueeze(-1)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
if self.vllm_config.cache_config.mamba_cache_mode != "all":
state_indices_tensor_d = state_indices_tensor_d[
:, : 1 + self.num_spec_tokens
]
state_indices_tensor_p = state_indices_tensor_p[:, 0]
if num_decodes > 0 and self.use_spec_decode:
assert num_accepted_tokens is not None
query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
num_accepted_tokens = num_accepted_tokens[:num_decodes]
if num_prefills > 0: if num_prefills > 0:
if num_computed_tokens is None: if num_computed_tokens is None:
...@@ -258,39 +330,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -258,39 +330,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs num_reqs - num_prefills : num_reqs
] ]
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_( metadata = self.metadata_cls(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return self.metadata_cls(
num_prefills=num_prefills, num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes, num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p, query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p, has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor, state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
num_accepted_tokens=num_accepted_tokens,
query_start_loc_d=query_start_loc_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token, block_idx_last_computed_token=block_idx_last_computed_token,
...@@ -302,55 +353,112 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -302,55 +353,112 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
token_chunk_offset_ptr=token_chunk_offset_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr,
) )
def update_block_table( return self._update_metadata_for_cudagraph_capture(metadata)
def _update_metadata_for_cudagraph_capture(
self, self,
metadata: M, metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M: ) -> M:
new_metadata = copy.copy(metadata) """
state_indices_t = mamba_get_block_table_tensor( Update the metadata for cudagraph capture.
blk_table, Currently, only decode is supported for full cudagraphs with Mamba.
metadata.seq_lens, """
self.kv_cache_spec, state_indices_tensor_d = metadata.state_indices_tensor_d
self.vllm_config.cache_config.mamba_cache_mode, query_start_loc_d = metadata.query_start_loc_d
) num_accepted_tokens = metadata.num_accepted_tokens
if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"): block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
# Only needs the block that saves the running state block_idx_last_computed_token = metadata.block_idx_last_computed_token
state_indices_t = state_indices_t[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if ( if (
metadata.num_prefills == 0 metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs and metadata.num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
): ):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs] padded_bs = metadata.num_reqs
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True) self.state_indices_tensor_d[: metadata.num_decodes].copy_(
state_indices_t = persistent_state_indices_t state_indices_tensor_d, non_blocking=True
)
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
if self.use_spec_decode:
assert query_start_loc_d is not None
assert num_accepted_tokens is not None
query_start_loc_d = query_start_loc_d[: padded_bs + 1]
self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.decode_num_accepted_tokens[:padded_bs]
num_accepted_tokens[metadata.num_decodes :] = (
1 # pad with 1st slot index
)
# For 'all' mode, also update prefix caching block indices
# to use this builder's persistent buffers (required for CUDA
# graph replay to read from the correct memory addresses).
if self.vllm_config.cache_config.mamba_cache_mode == "all": if self.vllm_config.cache_config.mamba_cache_mode == "all":
assert metadata.block_idx_last_scheduled_token is not None assert block_idx_last_scheduled_token is not None
assert metadata.block_idx_last_computed_token is not None assert block_idx_last_computed_token is not None
self.block_idx_last_scheduled_token[:num_reqs].copy_( self.block_idx_last_scheduled_token[: metadata.num_decodes].copy_(
metadata.block_idx_last_scheduled_token[:num_reqs], block_idx_last_scheduled_token[: metadata.num_decodes],
non_blocking=True, non_blocking=True,
) )
new_metadata.block_idx_last_scheduled_token = ( block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
self.block_idx_last_scheduled_token[: metadata.num_decode_tokens] : metadata.num_decode_tokens
) ]
self.block_idx_last_computed_token[:num_reqs].copy_(
metadata.block_idx_last_computed_token[:num_reqs], self.block_idx_last_computed_token[: metadata.num_decodes].copy_(
block_idx_last_computed_token[: metadata.num_decodes],
non_blocking=True, non_blocking=True,
) )
new_metadata.block_idx_last_computed_token = ( block_idx_last_computed_token = self.block_idx_last_computed_token[
self.block_idx_last_computed_token[: metadata.num_decode_tokens] : metadata.num_decode_tokens
) ]
return replace(
metadata,
state_indices_tensor_d=state_indices_tensor_d,
query_start_loc_d=query_start_loc_d,
num_accepted_tokens=num_accepted_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_last_computed_token=block_idx_last_computed_token,
)
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
state_indices_tensor = mamba_get_block_table_tensor(
blk_table,
metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
if state_indices_tensor.dim() == 1:
state_indices_tensor = state_indices_tensor.unsqueeze(-1)
assert (
metadata.num_prefills + metadata.num_decodes
== state_indices_tensor.shape[0]
), (
"Mismatch in number of requests when updating block table."
f" Expected {metadata.num_prefills + metadata.num_decodes}, "
f"got {state_indices_tensor.shape[0]}."
)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[metadata.num_decodes, metadata.num_prefills],
dim=0,
)
if self.vllm_config.cache_config.mamba_cache_mode != "all":
state_indices_tensor_d = state_indices_tensor_d[
:, : 1 + self.num_spec_tokens
]
state_indices_tensor_p = state_indices_tensor_p[:, 0]
new_metadata = replace(
metadata,
state_indices_tensor_d=state_indices_tensor_d,
state_indices_tensor_p=state_indices_tensor_p,
)
new_metadata.state_indices_tensor = state_indices_t return self._update_metadata_for_cudagraph_capture(new_metadata)
return new_metadata
...@@ -113,6 +113,7 @@ from vllm.v1.attention.backend import ( ...@@ -113,6 +113,7 @@ from vllm.v1.attention.backend import (
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
...@@ -1852,7 +1853,9 @@ class GPUModelRunner( ...@@ -1852,7 +1853,9 @@ class GPUModelRunner(
) )
extra_attn_metadata_args = {} extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode and isinstance(
builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder)
):
assert ubid is None, "UBatching not supported with GDN yet" assert ubid is None, "UBatching not supported with GDN yet"
extra_attn_metadata_args = dict( extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
...@@ -4725,7 +4728,7 @@ class GPUModelRunner( ...@@ -4725,7 +4728,7 @@ class GPUModelRunner(
# Set num_scheduled_tokens based on num_tokens and max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively # for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total. # has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens assert num_tokens <= self.max_num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs max_num_reqs = self.scheduler_config.max_num_seqs
if create_mixed_batch: if create_mixed_batch:
assert not uniform_decode assert not uniform_decode
...@@ -4849,6 +4852,7 @@ class GPUModelRunner( ...@@ -4849,6 +4852,7 @@ class GPUModelRunner(
ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices), ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices),
for_cudagraph_capture=is_graph_capturing, for_cudagraph_capture=is_graph_capturing,
slot_mappings=slot_mappings_by_group, slot_mappings=slot_mappings_by_group,
use_spec_decode=self.speculative_config is not None,
) )
with self.maybe_dummy_run_with_lora( with self.maybe_dummy_run_with_lora(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment