Unverified commit 8332fc49, authored by OlivierDehaene, committed by GitHub
Browse files

fix: use get_speculate to set the number of layers (#1737)

parent 743ecbca
......@@ -8,7 +8,8 @@ from typing import List, Tuple, Optional
from loguru import logger
from functools import lru_cache
# Dummy comment.
from text_generation_server.utils.speculate import get_speculate
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
......@@ -445,7 +446,7 @@ class MedusaModel(torch.nn.Module):
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
for i in range(medusa_config["medusa_num_heads"])
for i in range(get_speculate())
]
)
......@@ -542,7 +543,7 @@ class MedusaHeadV2(nn.Module):
)
routing[k] = filename
self.n_medusa_heads = medusa_config["medusa_num_heads"]
self.n_medusa_heads = get_speculate()
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment