"tests/vscode:/vscode.git/clone" did not exist on "5efa206a8cc5501563a79f667a5ae2f87dba2108"
Unverified commit b98d89ef, authored by Sky Lee and committed by GitHub
Browse files

[Misc] Medusa supports custom bias (#10361)

parent 8b6725b0
...@@ -14,11 +14,14 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -14,11 +14,14 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, hidden_size: int, num_layers: int) -> None: def __init__(self, config: VllmConfig, hidden_size: int,
num_layers: int) -> None:
super().__init__() super().__init__()
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False) nn.Linear(hidden_size,
hidden_size,
bias=getattr(config, "medusa_fc_bias", False))
for _ in range(num_layers) for _ in range(num_layers)
]) ])
self.act = nn.SiLU() self.act = nn.SiLU()
...@@ -49,7 +52,8 @@ class Medusa(nn.Module): ...@@ -49,7 +52,8 @@ class Medusa(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
ResidualBlock(hidden_size=self.config.hidden_size, ResidualBlock(config=config,
hidden_size=self.config.hidden_size,
num_layers=self.config.num_hidden_layers) num_layers=self.config.num_hidden_layers)
for _ in range(self.config.num_heads) for _ in range(self.config.num_heads)
]) ])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment