chenpangpang / transformers · Commit 517a3e67 (unverified)
Authored Apr 04, 2024 by Saurabh Dash, committed by GitHub on Apr 04, 2024

Refactor Cohere Model (#30027)

* changes
* addressing comments
* smol fix
parent 75b76a5e

Showing 2 changed files with 46 additions and 20 deletions (+46 -20):

  src/transformers/models/cohere/configuration_cohere.py   +4  -0
  src/transformers/models/cohere/modeling_cohere.py         +42 -20
src/transformers/models/cohere/configuration_cohere.py (+4 -0)

@@ -85,6 +85,8 @@ class CohereConfig(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        use_qk_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use query-key normalization in the attention
 
     ```python
     >>> from transformers import CohereModel, CohereConfig

@@ -123,6 +125,7 @@ class CohereConfig(PretrainedConfig):
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
+        use_qk_norm=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -145,6 +148,7 @@ class CohereConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.use_qk_norm = use_qk_norm
         super().__init__(
             pad_token_id=pad_token_id,
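For reference, the new option is consumed like any other `CohereConfig` keyword argument. A minimal sketch, assuming a transformers build that contains this commit; the tiny sizes are illustrative, not a released checkpoint:

```python
# Minimal sketch: enable the new flag on a toy config and check it is threaded through.
from transformers import CohereConfig, CohereModel

config = CohereConfig(
    vocab_size=1024,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    use_qk_norm=True,   # new in this commit, defaults to False
)
model = CohereModel(config)
print(model.config.use_qk_norm)  # True
```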
src/transformers/models/cohere/modeling_cohere.py (+42 -20)

@@ -76,10 +76,10 @@ def _get_unpad_data(attention_mask):
 
 
 class CohereLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):

@@ -89,8 +89,6 @@ class CohereLayerNorm(nn.Module):
         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
         hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype)
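The docstring added above is the core of the refactor: `hidden_size` may now be a tuple such as `(num_heads, head_dim)`, so the same module can serve both the usual hidden-state norm and a per-head query/key norm. A standalone sketch of that behaviour, re-implemented here for illustration (it mirrors the arithmetic shown in the hunk rather than importing the patched class):

```python
# Standalone sketch of a tuple-shaped layer norm for QK-norm.
import torch
import torch.nn as nn

class TupleLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        # hidden_size may be an int or a tuple such as (num_heads, head_dim)
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)

# Normalization is over head_dim only; the (num_heads, head_dim) weight gives every
# head its own scale and broadcasts over (batch, seq, num_heads, head_dim).
q = torch.randn(2, 5, 4, 16)                # (batch, seq_len, num_heads, head_dim)
norm = TupleLayerNorm(hidden_size=(4, 16))
print(norm(q).shape)                        # torch.Size([2, 5, 4, 16])
```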
@@ -122,7 +120,7 @@ class CohereRotaryEmbedding(nn.Module):
         emb = torch.repeat_interleave(freqs, 2, dim=-1)
         cos = emb.cos()
         sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        return cos, sin
 
 
 def rotate_half(x):
@@ -133,7 +131,6 @@ def rotate_half(x):
     return rot_x
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -154,11 +151,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
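Two dtype decisions move in opposite directions here: the rotary embedding no longer downcasts `cos`/`sin` to the input dtype, and `apply_rotary_pos_emb` instead upcasts `q`/`k` to float32, rotates, and only casts the result back at the end. A small sketch of the new flow, written as a standalone re-implementation for illustration (the `rotate_half` here assumes the interleaved layout that matches the `repeat_interleave` call above):

```python
# Sketch of the post-commit dtype flow: build cos/sin in float32, rotate in float32,
# cast back to the original (possibly half-precision) dtype at the end.
import torch

def rotate_half(x):
    # Interleaved rotation, matching the repeat_interleave layout of emb above (assumption).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).flatten(-2)

def apply_rotary_pos_emb_fp32(q, k, cos, sin, unsqueeze_dim=1):
    dtype = q.dtype
    q, k = q.float(), k.float()                                 # upcast before rotating
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)     # downcast at the end

# Illustrative half-precision inputs: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 4, 6, 8, dtype=torch.float16)
k = torch.randn(1, 4, 6, 8, dtype=torch.float16)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
freqs = torch.outer(torch.arange(6).float(), inv_freq)          # (seq_len, head_dim // 2)
emb = torch.repeat_interleave(freqs, 2, dim=-1)[None]           # (1, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb_fp32(q, k, emb.cos(), emb.sin())
print(q_rot.dtype)  # torch.float16, although the rotation itself ran in float32
```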
@@ -192,7 +192,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaAttention Llama->Cohere
 class CohereAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -216,6 +215,7 @@ class CohereAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.use_qk_norm = config.use_qk_norm
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(

@@ -223,6 +223,13 @@ class CohereAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
 
+        if self.use_qk_norm:
+            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
+            self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.k_norm = CohereLayerNorm(
+                hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
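Note the asymmetry in the two norms created above: the query norm is sized by `num_heads`, the key norm by `num_key_value_heads`, so the scheme still lines up under grouped-query attention where keys have fewer heads. A shape-only sketch with hypothetical sizes:

```python
# Hypothetical sizes: 8 query heads sharing 2 key/value heads (grouped-query attention).
import torch

num_heads, num_key_value_heads, head_dim = 8, 2, 64
bsz, q_len = 1, 3

q_norm_weight = torch.ones(num_heads, head_dim)            # q_norm covers all query heads
k_norm_weight = torch.ones(num_key_value_heads, head_dim)  # k_norm only the key/value heads

q = torch.randn(bsz, q_len, num_heads, head_dim)
k = torch.randn(bsz, q_len, num_key_value_heads, head_dim)
print((q_norm_weight * q).shape)  # torch.Size([1, 3, 8, 64])
print((k_norm_weight * k).shape)  # torch.Size([1, 3, 2, 64])
```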
@@ -255,8 +262,14 @@ class CohereAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
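The forward pass is reshuffled for the same reason: a `(num_heads, head_dim)` norm weight only broadcasts while those two axes are the trailing ones, so the projections are first viewed as `(bsz, q_len, heads, head_dim)`, optionally normalized, and only then transposed into the `(bsz, heads, q_len, head_dim)` layout attention expects. A shape sketch with illustrative sizes:

```python
# Illustrative sizes only; the weight stands in for q_norm.weight.
import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 16
weight = torch.ones(num_heads, head_dim)

query_states = torch.randn(bsz, q_len, num_heads * head_dim)        # q_proj output
query_states = query_states.view(bsz, q_len, num_heads, head_dim)   # norm-friendly layout
query_states = weight * query_states                                # per-head broadcast works here
query_states = query_states.transpose(1, 2)                         # (bsz, heads, q_len, head_dim)
print(query_states.shape)
# Transposing before the norm would give (bsz, heads, q_len, head_dim), where a
# (heads, head_dim) weight no longer lines up with the trailing axes.
```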
@@ -335,11 +348,14 @@ class CohereFlashAttention2(CohereAttention):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -505,7 +521,7 @@ class CohereSdpaAttention(CohereAttention):
     SDPA API.
     """
 
-    # Adapted from CohereAttention.forward
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -538,8 +554,14 @@ class CohereSdpaAttention(CohereAttention):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -599,7 +621,7 @@ class CohereDecoderLayer(nn.Module):
         self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = CohereMLP(config)
-        self.input_layernorm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -822,7 +844,7 @@ class CohereModel(CoherePreTrainedModel):
         self.layers = nn.ModuleList(
             [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing