Unverified commit 9117f892, authored by Saurabh Dash, committed by GitHub
Browse files

[Model] Cohere CommandR+ (#3829)

parent db2a6a41
...@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple ...@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
...@@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler ...@@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput ...@@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, bias=False): def __init__(self, param_shape=None, eps=1e-5):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(param_shape))
self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
self.variance_epsilon = eps self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
def forward(self, hidden_states, residuals=None): def forward(self, hidden_states, residuals=None):
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
...@@ -62,10 +64,20 @@ class LayerNorm(nn.Module): ...@@ -62,10 +64,20 @@ class LayerNorm(nn.Module):
hidden_states = (hidden_states - hidden_states = (hidden_states -
mean) * torch.rsqrt(variance + self.variance_epsilon) mean) * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight.to(torch.float32) * hidden_states hidden_states = self.weight.to(torch.float32) * hidden_states
if self.bias is not None:
hidden_states = hidden_states + self.bias.to(torch.float32)
return hidden_states.to(input_dtype), residuals return hidden_states.to(input_dtype), residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    """Copy a checkpoint tensor into ``param``, slicing multi-dim weights.

    A 1-D ``param`` is loaded whole. Any other ``param`` is treated as one
    shard along dim 0: the slice of ``loaded_weight`` that starts at
    ``rank * param.shape[0]`` and spans ``param.shape[0]`` rows is loaded,
    where ``rank`` comes from ``get_tensor_model_parallel_rank()``.

    Args:
        param: Destination parameter; its ``.data`` is overwritten in place.
        loaded_weight: Source tensor to copy from.

    Raises:
        AssertionError: If the (possibly narrowed) source shape differs
            from the destination shape.
    """
    param_data = param.data
    if param.dim() != 1:
        # The rank is looked up only when a shard is actually sliced, so
        # 1-D weights load without touching the parallel state.
        shard_dim = 0
        shard_size = param_data.shape[shard_dim]
        start_idx = get_tensor_model_parallel_rank() * shard_size
        loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
                                             shard_size)
    # copy_() accepts any broadcastable source, so insist on an exact
    # shape match rather than letting a wrong tensor be expanded.
    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module): class CohereMLP(nn.Module):
...@@ -131,6 +143,7 @@ class CohereAttention(nn.Module): ...@@ -131,6 +143,7 @@ class CohereAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.rope_scaling = getattr(config, "rope_scaling", None) self.rope_scaling = getattr(config, "rope_scaling", None)
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
self.hidden_size, self.hidden_size,
self.head_dim, self.head_dim,
...@@ -159,6 +172,22 @@ class CohereAttention(nn.Module): ...@@ -159,6 +172,22 @@ class CohereAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) )
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
eps=config.layer_norm_eps)
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
self.head_dim),
eps=config.layer_norm_eps)
def _apply_qk_norm(self, q, k):
    """Normalize query and key tensors per head and restore their layout.

    The last dimension of each input is split into
    ``(-1, self.head_dim)``, passed through ``self.q_norm`` /
    ``self.k_norm`` (only the first element of each returned pair is
    kept), and the two trailing dimensions are then merged back into one.

    Returns:
        The ``(q, k)`` pair after normalization, with the same number of
        dimensions as the inputs.
    """
    per_head = (-1, self.head_dim)
    q_heads = q.view(*q.shape[:-1], *per_head)
    k_heads = k.view(*k.shape[:-1], *per_head)
    q_normed, _ = self.q_norm(q_heads)
    k_normed, _ = self.k_norm(k_heads)
    q_flat = q_normed.view(*q_normed.shape[:-2], -1)
    k_flat = k_normed.view(*k_normed.shape[:-2], -1)
    return q_flat, k_flat
def forward( def forward(
self, self,
...@@ -169,6 +198,8 @@ class CohereAttention(nn.Module): ...@@ -169,6 +198,8 @@ class CohereAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
...@@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module): ...@@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module):
self.self_attn = CohereAttention(config, linear_method=linear_method) self.self_attn = CohereAttention(config, linear_method=linear_method)
self.mlp = CohereMLP(config, linear_method=linear_method) self.mlp = CohereMLP(config, linear_method=linear_method)
self.input_layernorm = LayerNorm(config.hidden_size, self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
def forward( def forward(
...@@ -229,7 +260,8 @@ class CohereModel(nn.Module): ...@@ -229,7 +260,8 @@ class CohereModel(nn.Module):
CohereDecoderLayer(config, linear_method=linear_method) CohereDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment