Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9117f892
Unverified
Commit
9117f892
authored
Apr 05, 2024
by
Saurabh Dash
Committed by
GitHub
Apr 04, 2024
Browse files
[Model] Cohere CommandR+ (#3829)
parent
db2a6a41
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
8 deletions
+40
-8
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+40
-8
No files found.
vllm/model_executor/models/commandr.py
View file @
9117f892
...
...
@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
...
...
@@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
...
...
@@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput
class
LayerNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
=
1e-5
,
bias
=
False
):
def
__init__
(
self
,
param_shape
=
None
,
eps
=
1e-5
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
hidden_size
))
if
bias
else
None
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
param_shape
))
self
.
variance_epsilon
=
eps
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
def
forward
(
self
,
hidden_states
,
residuals
=
None
):
input_dtype
=
hidden_states
.
dtype
...
...
@@ -62,10 +64,20 @@ class LayerNorm(nn.Module):
hidden_states
=
(
hidden_states
-
mean
)
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
hidden_states
=
self
.
weight
.
to
(
torch
.
float32
)
*
hidden_states
if
self
.
bias
is
not
None
:
hidden_states
=
hidden_states
+
self
.
bias
.
to
(
torch
.
float32
)
return
hidden_states
.
to
(
input_dtype
),
residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    """Copy this tensor-parallel rank's portion of a checkpoint tensor.

    A 1-D ``param`` is copied from ``loaded_weight`` unchanged. For any
    other rank of ``param``, ``loaded_weight`` is first narrowed along
    dim 0 to the window owned by the current tensor-parallel rank; the
    window length is ``param``'s own dim-0 extent.

    Args:
        param: Destination parameter; its ``.data`` is overwritten in place.
        loaded_weight: Tensor read from the checkpoint.

    Raises:
        AssertionError: If the (possibly narrowed) ``loaded_weight`` does
            not have exactly the same shape as ``param``.
    """
    rank = get_tensor_model_parallel_rank()
    is_sharded = param.dim() != 1
    destination = param.data

    if is_sharded:
        # Each rank holds `rows_per_rank` consecutive dim-0 entries, so the
        # window for this rank starts at rank * rows_per_rank.
        rows_per_rank = destination.shape[0]
        loaded_weight = loaded_weight.narrow(
            0, rank * rows_per_rank, rows_per_rank)

    assert destination.shape == loaded_weight.shape
    destination.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class
CohereMLP
(
nn
.
Module
):
...
...
@@ -131,6 +143,7 @@ class CohereAttention(nn.Module):
self
.
max_position_embeddings
=
config
.
max_position_embeddings
self
.
rope_theta
=
config
.
rope_theta
self
.
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
self
.
use_qk_norm
=
getattr
(
config
,
"use_qk_norm"
,
False
)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
...
...
@@ -159,6 +172,22 @@ class CohereAttention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
)
if
self
.
use_qk_norm
:
self
.
q_norm
=
LayerNorm
(
param_shape
=
(
self
.
num_heads
,
self
.
head_dim
),
eps
=
config
.
layer_norm_eps
)
self
.
k_norm
=
LayerNorm
(
param_shape
=
(
self
.
num_kv_heads
,
self
.
head_dim
),
eps
=
config
.
layer_norm_eps
)
def _apply_qk_norm(self, q, k):
    """Run ``self.q_norm`` / ``self.k_norm`` over per-head views of q and k.

    The last dimension of each input is split into ``(-1, self.head_dim)``
    so the norm modules see one trailing ``head_dim`` axis per head. The
    first element of each norm's return value is kept, and the two
    trailing axes are then merged back into a single last dimension.

    Args:
        q: Query tensor whose last dimension is a multiple of
            ``self.head_dim``.
        k: Key tensor whose last dimension is a multiple of
            ``self.head_dim``.

    Returns:
        Tuple ``(q, k)`` of the normalised tensors, each with its two
        trailing axes flattened into one.
    """
    # Split the flat last axis into (heads, head_dim).
    q_leading = q.shape[:-1]
    q_by_head = q.view(*q_leading, -1, self.head_dim)
    k_leading = k.shape[:-1]
    k_by_head = k.view(*k_leading, -1, self.head_dim)

    # The norm modules return a pair; only the first element is used here.
    q_normed, _ = self.q_norm(q_by_head)
    k_normed, _ = self.k_norm(k_by_head)

    # Merge (heads, head_dim) back into one trailing axis.
    q_flat = q_normed.view(*q_normed.shape[:-2], -1)
    k_flat = k_normed.view(*k_normed.shape[:-2], -1)
    return q_flat, k_flat
def
forward
(
self
,
...
...
@@ -169,6 +198,8 @@ class CohereAttention(nn.Module):
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_qk_norm
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
...
...
@@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module):
self
.
self_attn
=
CohereAttention
(
config
,
linear_method
=
linear_method
)
self
.
mlp
=
CohereMLP
(
config
,
linear_method
=
linear_method
)
self
.
input_layernorm
=
LayerNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
)
,
eps
=
config
.
layer_norm_eps
)
def
forward
(
...
...
@@ -229,7 +260,8 @@ class CohereModel(nn.Module):
CohereDecoderLayer
(
config
,
linear_method
=
linear_method
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
norm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
eps
=
config
.
layer_norm_eps
)
def
forward
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment