Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bddbbcb1
Unverified
Commit
bddbbcb1
authored
Dec 16, 2024
by
Jani Monoses
Committed by
GitHub
Dec 16, 2024
Browse files
[Model] Support Cohere2ForCausalLM (Cohere R7B) (#11203)
parent
b3b1526f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
26 additions
and
4 deletions
+26
-4
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+2
-2
tests/models/registry.py
tests/models/registry.py
+2
-0
tests/models/test_initialization.py
tests/models/test_initialization.py
+4
-0
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+17
-2
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/source/models/supported_models.rst
View file @
bddbbcb1
...
...
@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
- ✅︎
- ✅︎
* - :code:`CohereForCausalLM`
* - :code:`CohereForCausalLM`
,:code:`Cohere2ForCausalLM`
- Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
- :code:`CohereForAI/c4ai-command-r-v01`,
:code:`CohereForAI/c4ai-command-r7b-12-2024`,
etc.
- ✅︎
- ✅︎
* - :code:`DbrxForCausalLM`
...
...
tests/models/registry.py
View file @
bddbbcb1
...
...
@@ -53,6 +53,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# ChatGLMModel supports multimodal
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
trust_remote_code
=
True
),
"Cohere2ForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r7b-12-2024"
,
# noqa: E501
trust_remote_code
=
True
),
"DbrxForCausalLM"
:
_HfExamplesInfo
(
"databricks/dbrx-instruct"
),
"DeciLMForCausalLM"
:
_HfExamplesInfo
(
"Deci/DeciLM-7B-instruct"
,
trust_remote_code
=
True
),
...
...
tests/models/test_initialization.py
View file @
bddbbcb1
from
unittest.mock
import
patch
import
pytest
import
transformers
from
transformers
import
PretrainedConfig
from
vllm
import
LLM
...
...
@@ -11,6 +12,9 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
if
(
model_arch
==
"Cohere2ForCausalLM"
and
transformers
.
__version__
<
"4.48.0"
):
pytest
.
skip
(
reason
=
"Model introduced in HF >= 4.48.0"
)
if
not
model_info
.
is_available_online
:
pytest
.
skip
(
"Model is not available online"
)
...
...
vllm/model_executor/models/commandr.py
View file @
bddbbcb1
...
...
@@ -48,7 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -171,12 +171,26 @@ class CohereAttention(nn.Module):
rope_scaling
=
self
.
rope_scaling
,
is_neox_style
=
False
,
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
# Model v2 has sliding windows, v1 does not
self
.
v1
=
sliding_window
is
None
layer_idx
=
extract_layer_index
(
prefix
)
layer_has_sliding_window
=
(
getattr
(
config
,
"sliding_window_pattern"
,
False
)
and
(
layer_idx
+
1
)
%
self
.
config
.
sliding_window_pattern
!=
0
)
self
.
sliding_window
=
(
sliding_window
if
layer_has_sliding_window
else
None
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
per_layer_sliding_window
=
self
.
sliding_window
,
prefix
=
f
"
{
prefix
}
.attn"
)
if
self
.
use_qk_norm
:
self
.
q_norm
=
LayerNorm
(
param_shape
=
(
self
.
num_heads
,
...
...
@@ -206,7 +220,8 @@ class CohereAttention(nn.Module):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_qk_norm
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
self
.
v1
or
self
.
sliding_window
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
vllm/model_executor/models/registry.py
View file @
bddbbcb1
...
...
@@ -41,6 +41,7 @@ _TEXT_GENERATION_MODELS = {
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
# ChatGLMModel supports multimodal
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"Cohere2ForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment