"vscode:/vscode.git/clone" did not exist on "1282bd812ea4e1511378bad5b918d609280d2b89"
Commit 691e29ec (unverified), authored by Nick Hill, committed by GitHub
Browse files

[BugFix] Fix `MLPSpeculator` handling of `num_speculative_tokens` (#5876)

parent 3fd02bda
......@@ -920,15 +920,19 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs,
)
if (draft_model_config.hf_config.model_type == "mlp_speculator"
draft_hf_config = draft_model_config.hf_config
if (draft_hf_config.model_type == "mlp_speculator"
and target_parallel_config.world_size != 1):
# MLPSpeculator TP support will be added very soon
raise ValueError(
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1).")
n_predict = getattr(draft_model_config.hf_config, "n_predict",
None)
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
......
......@@ -11,6 +11,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import MLPSpeculatorConfig
class MLPSpeculatorLayerNorm(nn.Module):
......@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
class MLPSpeculator(nn.Module):
def __init__(self, config, **kwargs) -> None:
def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
super().__init__()
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
......@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
else config.emb_dim
self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
self.n_predict)
self.max_speculative_tokens = config.num_lookahead_tokens
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
......@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name.replace("speculator.", "")]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
param = params_dict.get(name.replace("speculator.", ""))
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -35,6 +35,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
candidate tree.
For each candidate branch in the tree, head n produces topk[n]
additional sub-branches.
NOTE: This parameter is currently unused.
n_candidates: int
number of child candidates to create per sequence
"""
......@@ -47,4 +48,6 @@ class MLPSpeculatorConfig(PretrainedConfig):
self.n_predict = n_predict
self.top_k_tokens_per_head = top_k_tokens_per_head
self.n_candidates = n_candidates
self.num_lookahead_tokens = n_predict
super().__init__(**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment