# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
import os
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once

from .utils import is_pp_missing_parameter, make_layers

# Amount passed to ``pad_weight`` by the LLAMA_NN weight preparation in
# ``QWenLMHeadModel.load_weights``. The FA_PAD slice in
# ``QWenAttention.forward`` strips exactly this many trailing columns from the
# fused QKV output, so the two must stay in sync.
_PAD_SIZE = 32


def _env_flag(name: str) -> bool:
    """Return True iff environment variable ``name`` is set to ``"1"``."""
    return os.environ.get(name) == "1"


class QWenMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, SiluAndMul, c_proj.

    Args:
        hidden_size: Model hidden size (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and a fused QKV linear.

    Heads are split evenly across tensor-parallel ranks.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by the tensor
            parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        # QWenLMHeadModel.load_weights pads the c_attn bias (and weight) for
        # FA_PAD only inside its LLAMA_NN, unquantized branch. Gate the slice
        # on the same conditions. Otherwise FA_PAD=1 without LLAMA_NN=1 would
        # cut real q/k/v values. Resolve the flag once here instead of
        # reading os.environ on every forward of every layer.
        self.strip_fa_pad = (_env_flag("FA_PAD") and _env_flag("LLAMA_NN")
                             and self.quant_method is None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        if self.strip_fa_pad:
            # Remove the padding columns before splitting into q/k/v.
            qkv = qkv[..., :-_PAD_SIZE]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One transformer layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded through the fused-residual form of
    ``RMSNorm``; ``forward`` returns ``(hidden_states, residual)``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The config's intermediate_size covers both gate and up projections.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class QWenModel(nn.Module):
    """Embedding + stack of QWenBlocks + final RMSNorm, pipeline-parallel aware.

    Non-last pipeline ranks return ``IntermediateTensors`` holding
    ``hidden_states`` and ``residual``; the last rank returns normalized
    hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    """QWen causal LM: ``QWenModel`` plus LM head, logits processor, sampler.

    Backend switches, read from the environment once at construction:
      * ``LLAMA_NN=1``: after loading, GEMM weights are rewritten with
        ``ops.trans_w16_gemm`` (unquantized models only).
      * ``GEMM_PAD=1``: GEMM weights for which ``gemm_bank_conf`` reports a
        conflict are padded with ``pad_weight`` (LLAMA_NN path only).
      * ``FA_PAD=1``: the fused QKV weight/bias are padded (LLAMA_NN path
        only) and ``QWenAttention`` strips the padding after the projection.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        self.use_llama_nn = _env_flag("LLAMA_NN")
        self.use_gemm_pad = _env_flag("GEMM_PAD")
        self.use_fa_pad = _env_flag("FA_PAD")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Zero-filled ``hidden_states``/``residual`` of shape
        ``(batch_size, hidden_size)`` for a non-first pipeline rank."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, then apply backend-specific layouts.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.
        """
        params_dict = dict(self.named_parameters())
        self._load_checkpoint_weights(weights, params_dict)
        if self.use_llama_nn and self.quant_method is None:
            self._prepare_llama_nn_weights(params_dict)
        if self.quant_method == "awq":
            self._prepare_awq_weights(params_dict)

    def _load_checkpoint_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]],
            params_dict: Dict[str, torch.Tensor]) -> None:
        """Copy checkpoint tensors into the matching model parameters.

        ``mlp.w2``/``mlp.w1`` are loaded as shards 0/1 of ``gate_up_proj``.
        Skipped: rotary ``inv_freq`` buffers, extra GPTQ biases,
        ``transformer.visual.*`` (Qwen-VL) tensors, and layers owned by other
        pipeline ranks.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading visual weights to support Qwen-VL models
                # in cases with text-only inputs
                # TODO: add support for Qwen-VL
                if (name not in params_dict
                        and name.startswith("transformer.visual.")):
                    print_warning_once(
                        "Only text inputs are allowed. Images won't be handled "
                        "until Qwen-VL models are fully supported.")
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def _prepare_llama_nn_weights(
            self, params_dict: Dict[str, torch.Tensor]) -> None:
        """Pad (GEMM_PAD / FA_PAD) and relayout GEMM weights for LLAMA_NN.

        Each selected weight is passed through ``ops.trans_w16_gemm`` into a
        buffer of the same shape and then reshaped to ``(ori_shape[1], -1)``.
        The exact layout produced by the custom op is kernel-specific —
        confirm against the ``trans_w16_gemm`` implementation.
        """
        gemm_keys = ("attn.c_attn.weight", "attn.c_proj.weight",
                     "mlp.gate_up_proj.weight", "mlp.c_proj.weight",
                     "lm_head.weight")
        for layername, weight in params_dict.items():
            if self.use_fa_pad and "attn.c_attn.bias" in layername:
                weight.data = pad_weight(weight.data, _PAD_SIZE)
            if not any(key in layername for key in gemm_keys):
                continue
            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, _PAD_SIZE)
            # NOTE(review): this re-evaluates gemm_bank_conf on the possibly
            # already padded shape. With FA_PAD=1 the bias above is always
            # padded once, so whether the QKV weight ends up padded exactly
            # once depends on gemm_bank_conf's semantics — confirm.
            if (self.use_fa_pad and "attn.c_attn.weight" in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, _PAD_SIZE)
            relaid = torch.zeros_like(weight.data)
            ori_shape = relaid.shape
            ops.trans_w16_gemm(relaid, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(relaid)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _prepare_awq_weights(self, params_dict: Dict[str,
                                                     torch.Tensor]) -> None:
        """Repack AWQ ``qweight``/``qzeros``/``scales`` via the custom s4 ops.

        Results are written into each layer's ``qweight`` and
        ``zeros_and_scales`` parameters. The packed layouts are defined by
        ``ops.convert_s4`` / ``ops.sz_permute``; confirm against those kernels.
        """
        awq_keys = ("attn.c_attn.qweight", "attn.c_proj.qweight",
                    "mlp.gate_up_proj.qweight", "mlp.c_proj.qweight")
        group_size = self.quant_config.group_size
        pad_group = 2
        for layername, qweight in params_dict.items():
            if not any(key in layername for key in awq_keys):
                continue
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]
            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # [k/group_size, n] ------> [n, k/group_size]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            # [k, n/8] ----> [n, k/8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0:
                # Append pad_group zero groups along K to zeros/scales and
                # the matching group_size // 4 int32 words of zero qweight.
                # The reason for 4096 is not visible here (presumably bank
                # conflicts) — confirm. Allocate on the parameter's device
                # rather than the current CUDA device.
                zs_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()