chenpangpang / transformers · Commits · 517a3e67

Unverified commit 517a3e67, authored Apr 04, 2024 by Saurabh Dash, committed by GitHub on Apr 04, 2024

Refactor Cohere Model (#30027)

* changes
* addressing comments
* smol fix
parent 75b76a5e
Showing 2 changed files with 46 additions and 20 deletions (+46 -20)
src/transformers/models/cohere/configuration_cohere.py  (+4 -0)
src/transformers/models/cohere/modeling_cohere.py  (+42 -20)
src/transformers/models/cohere/configuration_cohere.py (view file @ 517a3e67)

@@ -85,6 +85,8 @@ class CohereConfig(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        use_qk_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use query-key normalization in the attention

     ```python
     >>> from transformers import CohereModel, CohereConfig

@@ -123,6 +125,7 @@ class CohereConfig(PretrainedConfig):
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
+        use_qk_norm=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -145,6 +148,7 @@ class CohereConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.use_qk_norm = use_qk_norm

         super().__init__(
             pad_token_id=pad_token_id,
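For context, a minimal sketch (not part of the commit) of how the new flag is passed through `CohereConfig`; the tiny sizes below are arbitrary and chosen only to keep the example cheap to run:

```python
from transformers import CohereConfig, CohereModel

# Arbitrary small sizes for illustration; only `use_qk_norm` is the point here.
config = CohereConfig(
    vocab_size=1024,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    use_qk_norm=True,  # new flag introduced by this commit (defaults to False)
)
model = CohereModel(config)
# With the flag on, every attention layer carries q_norm/k_norm CohereLayerNorm modules.
print(model.config.use_qk_norm)  # True
```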
src/transformers/models/cohere/modeling_cohere.py (view file @ 517a3e67)
@@ -76,10 +76,10 @@ def _get_unpad_data(attention_mask):

 class CohereLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
         self.variance_epsilon = eps

     def forward(self, hidden_states):
@@ -89,8 +89,6 @@ class CohereLayerNorm(nn.Module):
         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
         hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype)
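To illustrate why the tuple `hidden_size` matters, here is a hedged, self-contained sketch of the normalization logic reconstructed from the two hunks above (the surrounding context lines, such as the upcast and mean computation, are assumed rather than shown in this diff): the statistics are still taken over the last dimension, so with `hidden_size=(num_heads, head_dim)` each head is normalized across its own `head_dim`, while the scale becomes a per-head `(num_heads, head_dim)` parameter.

```python
import torch
from torch import nn


class MiniCohereLayerNorm(nn.Module):
    """Sketch of CohereLayerNorm after this commit; hidden_size may be an int or a tuple."""

    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Context lines not shown in the hunk are assumed: upcast, then mean over the last dim.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)


# QK-norm style usage: normalize a (batch, seq, num_heads, head_dim) tensor per head.
bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
q = torch.randn(bsz, q_len, num_heads, head_dim)
q_norm = MiniCohereLayerNorm(hidden_size=(num_heads, head_dim))
out = q_norm(q)
print(out.shape)                            # torch.Size([2, 5, 4, 8])
print(out.var(-1, unbiased=False).mean())   # ~1.0: each head normalized over head_dim
```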
@@ -122,7 +120,7 @@ class CohereRotaryEmbedding(nn.Module):
         emb = torch.repeat_interleave(freqs, 2, dim=-1)
         cos = emb.cos()
         sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        return cos, sin


 def rotate_half(x):
@@ -133,7 +131,6 @@ def rotate_half(x):
     return rot_x


-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
@@ -154,11 +151,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
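A brief, hedged sketch of the dtype handling introduced in this hunk: the rotation is now computed in float32 and cast back to the caller's dtype at the end, instead of casting `cos`/`sin` inside the rotary embedding. The `rotate_half_stub` below is a placeholder, not Cohere's actual `rotate_half`, since its body is not part of this diff.

```python
import torch


def rotate_half_stub(x):
    # Placeholder for the real rotate_half; any rotation works for this dtype demo.
    return torch.roll(x, shifts=1, dims=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # Pattern added by this hunk: remember the input dtype, do the math in float32, cast back.
    dtype = q.dtype
    q, k = q.float(), k.float()
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half_stub(q) * sin)
    k_embed = (k * cos) + (rotate_half_stub(k) * sin)
    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


q = torch.randn(2, 4, 5, 8, dtype=torch.float16)  # (bsz, heads, seq, head_dim)
k = torch.randn(2, 4, 5, 8, dtype=torch.float16)
cos = torch.randn(2, 5, 8)  # float32, as returned by the rotary embedding after this change
sin = torch.randn(2, 5, 8)
q_out, k_out = apply_rotary_pos_emb_sketch(q, k, cos, sin)
print(q_out.dtype, k_out.dtype)  # torch.float16 torch.float16, despite float32 intermediates
```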
@@ -192,7 +192,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-# Copied from transformers.models.llama.modeling_llama.LlamaAttention Llama->Cohere
 class CohereAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -216,6 +215,7 @@ class CohereAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.use_qk_norm = config.use_qk_norm

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -223,6 +223,13 @@ class CohereAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )

+        if self.use_qk_norm:
+            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
+            self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.k_norm = CohereLayerNorm(
+                hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
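As a quick sanity check on the shapes implied by this hunk (a sketch under the assumption, shown earlier, that `CohereLayerNorm` stores its scale as `torch.ones(hidden_size)`): with a tuple `hidden_size` the norm's weight spans all heads jointly, which is why the comment warns about using the local head count under tensor parallelism. The 2-way sharding below is hypothetical.

```python
import torch
from torch import nn

num_heads, num_key_value_heads, head_dim = 8, 4, 16

# Single-device case, mirroring the constructor calls above.
q_norm_weight = nn.Parameter(torch.ones((num_heads, head_dim)))            # (8, 16)
k_norm_weight = nn.Parameter(torch.ones((num_key_value_heads, head_dim)))  # (4, 16)

# Hypothetical 2-way tensor-parallel shard: the norm must be built with the *local*
# head count, otherwise its weight no longer lines up with the sharded heads.
tp_size = 2
n_local_heads = num_heads // tp_size
q_norm_weight_shard = nn.Parameter(torch.ones((n_local_heads, head_dim)))  # (4, 16)
print(q_norm_weight.shape, q_norm_weight_shard.shape)
```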
@@ -255,8 +262,14 @@ class CohereAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         past_key_value = getattr(self, "past_key_value", past_key_value)
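The reordering above is the functional core of the change: the projections are first viewed as `(bsz, q_len, heads, head_dim)`, optionally passed through the per-head norm, and only then transposed into the `(bsz, heads, q_len, head_dim)` layout the rest of attention expects. A hedged, standalone sketch of that flow (the sizes and the simplified `per_head_norm` helper are illustrative, not the commit's code):

```python
import torch
from torch import nn


def per_head_norm(x, weight, eps=1e-5):
    # Minimal stand-in for CohereLayerNorm with a tuple hidden_size: statistics over
    # the last dim (head_dim), with a learnable scale of shape (num_heads, head_dim).
    mean = x.mean(-1, keepdim=True)
    variance = (x - mean).pow(2).mean(-1, keepdim=True)
    return weight * (x - mean) * torch.rsqrt(variance + eps)


bsz, q_len, hidden_size, num_heads = 2, 6, 64, 4
head_dim = hidden_size // num_heads

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
q_norm_weight = torch.ones(num_heads, head_dim)

hidden_states = torch.randn(bsz, q_len, hidden_size)
query_states = q_proj(hidden_states)

# 1. View into per-head layout *before* transposing, so the norm sees (..., num_heads, head_dim).
query_states = query_states.view(bsz, q_len, num_heads, head_dim)
# 2. Apply QK-norm per head (only when use_qk_norm is enabled in the real model).
query_states = per_head_norm(query_states, q_norm_weight)
# 3. Only then move to the (bsz, num_heads, q_len, head_dim) layout used by the attention math.
query_states = query_states.transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 4, 6, 4])
```

The same view → norm → transpose pattern is applied in the eager, FlashAttention2, and SDPA attention variants below.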
@@ -335,11 +348,14 @@ class CohereFlashAttention2(CohereAttention):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -505,7 +521,7 @@ class CohereSdpaAttention(CohereAttention):
     SDPA API.
     """

-    # Adapted from CohereAttention.forward
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -538,8 +554,14 @@ class CohereSdpaAttention(CohereAttention):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -599,7 +621,7 @@ class CohereDecoderLayer(nn.Module):
         self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

         self.mlp = CohereMLP(config)
-        self.input_layernorm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)

     def forward(
         self,
@@ -822,7 +844,7 @@ class CohereModel(CoherePreTrainedModel):
         self.layers = nn.ModuleList(
             [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing