"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bcd607542cddc1240f380db07e681fac9ff27918"
Unverified commit cfb22e03, authored by Mansu Kim and committed by GitHub
Browse files

Support Clip QKV for MPT (#31307)

parent b7672826
......@@ -82,6 +82,7 @@ class MptAttention(nn.Module):
self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
self.attn_dropout_p = config.attn_config.attn_pdrop
self.clip_qkv = config.attn_config.clip_qkv
self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
......@@ -95,6 +96,9 @@ class MptAttention(nn.Module):
batch_size, seq_length = hidden_states.shape[:2]
mixed_qkv = self.Wqkv(hidden_states)
if self.clip_qkv:
mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment