Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bddbbcb1
Unverified
Commit
bddbbcb1
authored
Dec 16, 2024
by
Jani Monoses
Committed by
GitHub
Dec 16, 2024
Browse files
[Model] Support Cohere2ForCausalLM (Cohere R7B) (#11203)
parent
b3b1526f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
26 additions
and
4 deletions
+26
-4
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+2
-2
tests/models/registry.py
tests/models/registry.py
+2
-0
tests/models/test_initialization.py
tests/models/test_initialization.py
+4
-0
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+17
-2
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/source/models/supported_models.rst
View file @
bddbbcb1
...
@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
...
@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
- ✅︎
- ✅︎
- ✅︎
- ✅︎
* - :code:`CohereForCausalLM`
* - :code:`CohereForCausalLM`
,:code:`Cohere2ForCausalLM`
- Command-R
- Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
- :code:`CohereForAI/c4ai-command-r-v01`,
:code:`CohereForAI/c4ai-command-r7b-12-2024`,
etc.
- ✅︎
- ✅︎
- ✅︎
- ✅︎
* - :code:`DbrxForCausalLM`
* - :code:`DbrxForCausalLM`
...
...
tests/models/registry.py
View file @
bddbbcb1
...
@@ -53,6 +53,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -53,6 +53,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# ChatGLMModel supports multimodal
# ChatGLMModel supports multimodal
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
"CohereForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r-v01"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Cohere2ForCausalLM"
:
_HfExamplesInfo
(
"CohereForAI/c4ai-command-r7b-12-2024"
,
# noqa: E501
trust_remote_code
=
True
),
"DbrxForCausalLM"
:
_HfExamplesInfo
(
"databricks/dbrx-instruct"
),
"DbrxForCausalLM"
:
_HfExamplesInfo
(
"databricks/dbrx-instruct"
),
"DeciLMForCausalLM"
:
_HfExamplesInfo
(
"Deci/DeciLM-7B-instruct"
,
"DeciLMForCausalLM"
:
_HfExamplesInfo
(
"Deci/DeciLM-7B-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
...
tests/models/test_initialization.py
View file @
bddbbcb1
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
import
transformers
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm
import
LLM
from
vllm
import
LLM
...
@@ -11,6 +12,9 @@ from .registry import HF_EXAMPLE_MODELS
...
@@ -11,6 +12,9 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
):
def
test_can_initialize
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
if
(
model_arch
==
"Cohere2ForCausalLM"
and
transformers
.
__version__
<
"4.48.0"
):
pytest
.
skip
(
reason
=
"Model introduced in HF >= 4.48.0"
)
if
not
model_info
.
is_available_online
:
if
not
model_info
.
is_available_online
:
pytest
.
skip
(
"Model is not available online"
)
pytest
.
skip
(
"Model is not available online"
)
...
...
vllm/model_executor/models/commandr.py
View file @
bddbbcb1
...
@@ -48,7 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -48,7 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -171,12 +171,26 @@ class CohereAttention(nn.Module):
...
@@ -171,12 +171,26 @@ class CohereAttention(nn.Module):
rope_scaling
=
self
.
rope_scaling
,
rope_scaling
=
self
.
rope_scaling
,
is_neox_style
=
False
,
is_neox_style
=
False
,
)
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
# Model v2 has sliding windows, v1 does not
self
.
v1
=
sliding_window
is
None
layer_idx
=
extract_layer_index
(
prefix
)
layer_has_sliding_window
=
(
getattr
(
config
,
"sliding_window_pattern"
,
False
)
and
(
layer_idx
+
1
)
%
self
.
config
.
sliding_window_pattern
!=
0
)
self
.
sliding_window
=
(
sliding_window
if
layer_has_sliding_window
else
None
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
per_layer_sliding_window
=
self
.
sliding_window
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
)
if
self
.
use_qk_norm
:
if
self
.
use_qk_norm
:
self
.
q_norm
=
LayerNorm
(
param_shape
=
(
self
.
num_heads
,
self
.
q_norm
=
LayerNorm
(
param_shape
=
(
self
.
num_heads
,
...
@@ -206,6 +220,7 @@ class CohereAttention(nn.Module):
...
@@ -206,6 +220,7 @@ class CohereAttention(nn.Module):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_qk_norm
:
if
self
.
use_qk_norm
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
if
self
.
v1
or
self
.
sliding_window
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
...
...
vllm/model_executor/models/registry.py
View file @
bddbbcb1
...
@@ -41,6 +41,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -41,6 +41,7 @@ _TEXT_GENERATION_MODELS = {
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
# ChatGLMModel supports multimodal
# ChatGLMModel supports multimodal
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"Cohere2ForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment