Unverified Commit 8a25d3a7 authored by YingchaoX's avatar YingchaoX Committed by GitHub
Browse files

fix stablelm.py tensor-parallel-size bug (#2482)

parent d10f8e1d
......@@ -99,7 +99,7 @@ class StablelmAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
if (self.head_dim * self.num_heads) != self.hidden_size:
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment