Unverified Commit c38eba30 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Bugfix] MLPSpeculator: Use ParallelLMHead in tie_weights=False case. (#6303)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent e72ae80b
......@@ -110,7 +110,7 @@ class MLPSpeculator(nn.Module):
])
self.head = nn.ModuleList([
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment