# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig

# Per-layer cache pair, unpacked as (key_cache, value_cache) in GLMAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GLMAttention(nn.Module):
    """Self-attention layer: fused QKV projection, paged attention with
    rotary embedding, and an output projection.

    Query heads are divided evenly across the tensor-parallel ranks. KV
    heads are divided when there are at least as many of them as ranks,
    otherwise each rank keeps a single KV head.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections and the attention op from ``config``.

        Args:
            config: Model config; the fields read here are hidden_size,
                num_attention_heads, multi_query_attention,
                multi_query_group_num, add_bias_linear and add_qkv_bias.
            linear_method: Forwarded unchanged to the parallel linear layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query attention the KV head count comes from
        # multi_query_group_num; otherwise there is one KV head per
        # attention head.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths of the q / k / v slices of the fused QKV output,
        # used by the split in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )
        # The rotary embedding covers only half of each head's dimensions
        # (rotary_dim = head_dim // 2) and uses the non-NeoX layout.
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim // 2,
            num_kv_heads=self.num_kv_heads,
            is_neox_style=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to QKV, run paged attention, and apply the output
        projection.

        Returns:
            The first element returned by the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        key_cache, value_cache = kv_cache
        context_layer = self.attn(
            position_ids,
            q,
            k,
            v,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Feed-forward block.

    The input is projected by one merged column-parallel layer producing two
    outputs of width ``config.ffn_hidden_size`` each, passed through
    ``SiluAndMul``, and projected back to ``config.hidden_size``.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the two projections.

        Args:
            config: Model config; the fields read here are hidden_size,
                ffn_hidden_size and add_bias_linear.
            linear_method: Forwarded unchanged to the parallel linear layers.
        """
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Up projection: two ffn_hidden_size-wide outputs merged in one layer.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

        self.activation_func = SiluAndMul()

        # Project back to hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

    def forward(self, hidden_states):
        """Apply up projection, activation, and down projection in order."""
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Applies a layer norm, self-attention with a residual connection, a second
    layer norm, and the MLP with a second residual connection. The output has
    the same size as the input.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the norms, attention and MLP for one layer.

        Args:
            config: Model config; in addition to what GLMAttention and GLMMLP
                read, this uses apply_residual_connection_post_layernorm,
                fp32_residual_connection, rmsnorm, layernorm_epsilon and
                hidden_dropout.
            linear_method: Forwarded unchanged to the attention and MLP.
        """
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        # Stored from the config; not read by forward() in this class.
        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, linear_method)
        # Stored from the config; not read by forward() in this class.
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run one transformer layer.

        When ``apply_residual_connection_post_layernorm`` is set, each
        residual branch starts from the normalized tensor; otherwise it
        starts from the tensor that was fed into the corresponding norm.
        """
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``config.num_layers`` GLMBlock layers with an optional final
    layer norm."""

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the layer stack.

        Args:
            config: Model config; in addition to what GLMBlock reads, this
                uses post_layer_norm and num_layers.
            linear_method: Forwarded unchanged to every GLMBlock.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config, linear_method) for _ in range(self.num_layers)])

        # final_layernorm only exists when post_layer_norm is enabled.
        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Run every layer in order, then the final layer norm if enabled.

        Layer ``i`` receives ``kv_caches[i]`` and, when ``cache_events`` is
        not None, ``cache_events[i]``.
        """
        # self.layers holds exactly num_layers blocks, so enumerating it
        # visits the same indices as range(self.num_layers).
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                input_metadata=input_metadata,
                cache_event=cache_event,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """Token embedding followed by the GLMTransformer encoder.

    Also owns ``output_layer`` (the LM head); it is not applied in
    ``forward`` of this class.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, encoder and LM head.

        Args:
            config: Model config; in addition to what GLMTransformer reads,
                this uses padded_vocab_size, multi_query_group_num and
                kv_channels.
            linear_method: Forwarded unchanged to the encoder.
        """
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, linear_method)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ):
        """Embed ``input_ids`` and return the encoder's hidden states."""
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )

        return hidden_states


class ChatGLMForCausalLM(nn.Module):
    """ChatGLM with a sampler on top; also loads checkpoint weights."""

    def __init__(
        self,
        config: ChatGLMConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the transformer and the sampler.

        Args:
            config: ChatGLM configuration.
            linear_method: Forwarded unchanged to ChatGLMModel.
        """
        super().__init__()
        self.config: ChatGLMConfig = config
        self.linear_method = linear_method
        self.transformer = ChatGLMModel(config, linear_method)
        # Alias of the transformer's LM-head weight, handed to the sampler
        # in forward().
        self.lm_head_weight = self.transformer.output_layer.weight
        self.sampler = Sampler(config.padded_vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the transformer and return the sampler's output."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy checkpoint tensors into this model's parameters.

        Iterates over ``hf_model_weights_iterator`` and, for each
        ``(name, tensor)`` pair:

        * skips names containing ``"rotary_pos_emb.inv_freq"``;
        * removes ``".word_embeddings"`` from names that contain
          ``"word_embeddings"``;
        * loads the tensor with the parameter's own ``weight_loader``
          attribute when it has one, else ``default_weight_loader``.

        Args:
            model_name_or_path: Passed through to the weights iterator.
            cache_dir: Passed through to the weights iterator.
            load_format: Passed through to the weights iterator.
            revision: Passed through to the weights iterator.

        Raises:
            KeyError: If a (possibly renamed) checkpoint name does not match
                any parameter of this module.
        """
        # remove_duplicate=False keeps every name under which a shared
        # parameter is registered.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)