Unverified Commit 3a0e1fc0 authored by dakotamahan-stability, committed by GitHub

Support for Stable LM 2 (#2598)


Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 6b7de1a0
...@@ -98,7 +98,7 @@ class StablelmAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
+        self.qkv_bias = getattr(config, "use_qkv_bias", False)
         if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
...@@ -108,7 +108,7 @@ class StablelmAttention(nn.Module):
                                           self.head_dim,
                                           self.total_num_heads,
                                           self.total_num_key_value_heads,
-                                          bias=False,
+                                          self.qkv_bias,
                                           linear_method=linear_method)
         self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                         self.hidden_size,
...
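For context on what the patch does: it reads an optional use_qkv_bias attribute from the model config (defaulting to False) and passes it through to the fused QKV projection, so Stable LM 2 checkpoints that enable query/key/value biases load with those bias weights, while existing StableLM configs that omit the flag keep the previous bias-free behaviour. Below is a minimal plain-PyTorch sketch of that bias handling; it deliberately does not use vLLM's QKVParallelLinear, and the SimpleConfig / FusedQKV names are made up for the example.

import torch
import torch.nn as nn


class SimpleConfig:
    # Hypothetical stand-in for the Hugging Face model config; only the
    # fields needed by this sketch are modelled.
    def __init__(self, hidden_size=2048, use_qkv_bias=True):
        self.hidden_size = hidden_size
        self.use_qkv_bias = use_qkv_bias


class FusedQKV(nn.Module):
    # Toy fused QKV projection mirroring the bias handling added by the patch.
    def __init__(self, config):
        super().__init__()
        # Same pattern as the diff: default to no bias when the config does
        # not define use_qkv_bias, so older StableLM checkpoints are unchanged.
        qkv_bias = getattr(config, "use_qkv_bias", False)
        self.qkv_proj = nn.Linear(config.hidden_size,
                                  3 * config.hidden_size,
                                  bias=qkv_bias)

    def forward(self, hidden_states):
        # Split the fused projection back into query, key and value.
        return self.qkv_proj(hidden_states).chunk(3, dim=-1)


# Usage: a config with use_qkv_bias=True gives the fused projection a bias term.
q, k, v = FusedQKV(SimpleConfig())(torch.randn(1, 4, 2048))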