Commit 1e0cb1f4 authored by zhuwenwen's avatar zhuwenwen
Browse files

support nn layout

parent 281ca6c1
......@@ -119,7 +119,8 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
......
......@@ -1548,7 +1548,7 @@ __global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* sr
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row; //dst的列id
int64_t j=id%row;
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
......
......@@ -160,7 +160,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row,int col) -> ()");
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM.
......
......@@ -89,7 +89,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
......@@ -110,15 +109,13 @@ class UnquantizedLinearMethod(LinearMethodBase):
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
if self.use_llama_nn:
if bias is not None:
return torch.matmul(x, weight) +bias
return torch.addmm(bias, x, weight)
else:
return torch.matmul(x, weight)
else:
......
......@@ -22,13 +22,9 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if architectures == ['LlamaForCausalLM'] or architectures == ['QWenLMHeadModel'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
# Special handling for quantized Mixtral.
......
......@@ -25,6 +25,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -45,6 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
......@@ -179,8 +181,6 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
......@@ -329,6 +329,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -397,6 +398,26 @@ class BaiChuanBaseForCausalLM(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.W_pack.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B."""
......
......@@ -8,6 +8,7 @@ import torch
from torch import nn
from torch.nn import LayerNorm
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -28,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops
class GLMAttention(nn.Module):
......@@ -102,8 +104,6 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
......@@ -356,6 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -396,3 +397,23 @@ class ChatGLMForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
......@@ -159,8 +159,6 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -444,29 +442,24 @@ class LlamaForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
weight.data=weight.data.reshape(ori_shape[1], -1)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
......
......@@ -298,28 +298,21 @@ class QWenLMHeadModel(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"attn.c_attn.weight",
"attn.c_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
"mlp.c_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
......
......@@ -150,8 +150,6 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -386,28 +384,21 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment