Unverified Commit 328841d0 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[bugfix] interleaving sliding window for cohere2 model (#11583)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent d427e5cf
...@@ -112,7 +112,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -112,7 +112,7 @@ See [this page](#generative-models) for more information on how to use generativ
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
- ✅︎ - ✅︎
- ✅︎ - ✅︎
* - :code:`CohereForCausalLM`,:code:`Cohere2ForCausalLM` * - :code:`CohereForCausalLM`, :code:`Cohere2ForCausalLM`
- Command-R - Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc. - :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
- ✅︎ - ✅︎
......
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import transformers
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm import LLM from vllm import LLM
...@@ -12,9 +11,6 @@ from .registry import HF_EXAMPLE_MODELS ...@@ -12,9 +11,6 @@ from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch): def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if (model_arch == "Cohere2ForCausalLM"
and transformers.__version__ < "4.48.0"):
pytest.skip(reason="Model introduced in HF >= 4.48.0")
if not model_info.is_available_online: if not model_info.is_available_online:
pytest.skip("Model is not available online") pytest.skip("Model is not available online")
......
...@@ -301,7 +301,7 @@ class ModelConfig: ...@@ -301,7 +301,7 @@ class ModelConfig:
sliding_window = getattr(self.hf_text_config, "sliding_window", None) sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and ( has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2"])) (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
if (not self.disable_sliding_window and has_interleaved_attention): if (not self.disable_sliding_window and has_interleaved_attention):
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS": if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
......
...@@ -172,16 +172,18 @@ class CohereAttention(nn.Module): ...@@ -172,16 +172,18 @@ class CohereAttention(nn.Module):
is_neox_style=False, is_neox_style=False,
) )
sliding_window = getattr(config, "sliding_window", None) # Model v2 has interleaved sliding windows, v1 does not
# Model v2 has sliding windows, v1 does not interleaved_sliding_window = getattr(config,
self.v1 = sliding_window is None "interleaved_sliding_window",
None)
self.v1 = interleaved_sliding_window is None
layer_idx = extract_layer_index(prefix) layer_idx = extract_layer_index(prefix)
layer_has_sliding_window = ( layer_has_sliding_window = (
getattr(config, "sliding_window_pattern", False) getattr(config, "sliding_window_pattern", False)
and (layer_idx + 1) % self.config.sliding_window_pattern != 0) and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
self.sliding_window = (sliding_window self.sliding_window = (interleaved_sliding_window
if layer_has_sliding_window else None) if layer_has_sliding_window else None)
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
......
...@@ -22,9 +22,9 @@ from vllm.envs import VLLM_USE_MODELSCOPE ...@@ -22,9 +22,9 @@ from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
EAGLEConfig, ExaoneConfig, DbrxConfig, EAGLEConfig,
H2OVLChatConfig, ExaoneConfig, H2OVLChatConfig,
InternVLChatConfig, JAISConfig, InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig, MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig, MLPSpeculatorConfig, MPTConfig,
...@@ -52,6 +52,7 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { ...@@ -52,6 +52,7 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
......
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig
...@@ -22,6 +23,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig ...@@ -22,6 +23,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [ __all__ = [
"ChatGLMConfig", "ChatGLMConfig",
"Cohere2Config",
"DbrxConfig", "DbrxConfig",
"MPTConfig", "MPTConfig",
"RWConfig", "RWConfig",
......
# ruff: noqa
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere2/configuration_cohere2.py
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class Cohere2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate an Cohere
model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`CohereModel`]
hidden_size (`int`, *optional*, defaults to 8192):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22528):
Dimension of the MLP representations.
logit_scale (`float`, *optional*, defaults to 0.0625):
The scaling factor for the output logits.
num_hidden_layers (`int`, *optional*, defaults to 40):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 5):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 255001):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
sliding_window (`int`, *optional*, defaults to 4096):
Size of the sliding window attention context.
sliding_window_pattern (`int`, *optional*, defaults to 4):
Pattern for the sliding window attention.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
```python
>>> from transformers import Cohere2Model, Cohere2Config
>>> # Initializing a Cohere Nextmodel configuration
>>> configuration = Cohere2Config()
>>> # Initializing a model from the Cohere2 configuration
>>> model = Cohere2Model(configuration) # doctest: +SKIP
>>> # Accessing the model configuration
>>> configuration = model.config # doctest: +SKIP
```
"""
model_type = "cohere2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=8192,
intermediate_size=22528,
logit_scale=0.0625,
num_hidden_layers=40,
num_attention_heads=64,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=8192,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=5,
eos_token_id=255001,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
sliding_window=4096,
sliding_window_pattern=4,
cache_implementation="hybrid",
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.logit_scale = logit_scale
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.sliding_window = sliding_window
self.sliding_window_pattern = sliding_window_pattern
# Need to specify head_dim in the config so it can be used in the attention forward functions
self.head_dim = hidden_size // num_attention_heads
self.cache_implementation = cache_implementation
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["Cohere2Config"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment