Commit 1e0cb1f4 authored by zhuwenwen's avatar zhuwenwen
Browse files

support nn layout

parent 281ca6c1
...@@ -119,7 +119,8 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -119,7 +119,8 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col); void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale); // torch::Tensor& scale);
......
...@@ -1548,7 +1548,7 @@ __global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* sr ...@@ -1548,7 +1548,7 @@ __global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* sr
int64_t id = blockIdx.x * blockDim.x + threadIdx.x; int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return; if(id >= num_kernels) return;
int64_t j=id%row; //dst的列id int64_t j=id%row;
int64_t i=id/row; int64_t i=id/row;
dst[i*row+j]=src[j*col+i]; dst[i*row+j]=src[j*col+i];
......
...@@ -160,7 +160,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -160,7 +160,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16 // trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row,int col) -> ()"); ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm); ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM. // Quantized GEMM for SqueezeLLM.
......
...@@ -89,7 +89,6 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -89,7 +89,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: List[int], input_size: int,
...@@ -110,15 +109,13 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -110,15 +109,13 @@ class UnquantizedLinearMethod(LinearMethodBase):
weight = layer.weight weight = layer.weight
if self.separate_bias_add: if self.separate_bias_add:
if bias is not None: if bias is not None:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
return F.linear(x, weight) return F.linear(x, weight)
if self.use_llama_nn: if self.use_llama_nn:
if bias is not None: if bias is not None:
return torch.matmul(x, weight) +bias return torch.addmm(bias, x, weight)
else: else:
return torch.matmul(x, weight) return torch.matmul(x, weight)
else: else:
......
...@@ -22,13 +22,9 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,13 +22,9 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']: if architectures == ['LlamaForCausalLM'] or architectures == ['QWenLMHeadModel'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
......
...@@ -25,6 +25,7 @@ import torch ...@@ -25,6 +25,7 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import os import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -45,6 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -45,6 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -179,8 +181,6 @@ class BaiChuanAttention(nn.Module): ...@@ -179,8 +181,6 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -329,6 +329,7 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -329,6 +329,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -396,6 +397,26 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -396,6 +397,26 @@ class BaiChuanBaseForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.W_pack.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
import os import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -28,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
...@@ -102,8 +104,6 @@ class GLMAttention(nn.Module): ...@@ -102,8 +104,6 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn( context_layer = self.attn(
...@@ -356,6 +356,7 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -356,6 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -396,3 +397,23 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -396,3 +397,23 @@ class ChatGLMForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
...@@ -159,8 +159,6 @@ class LlamaAttention(nn.Module): ...@@ -159,8 +159,6 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -444,29 +442,24 @@ class LlamaForCausalLM(nn.Module): ...@@ -444,29 +442,24 @@ class LlamaForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.down_proj.weight"
] ]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1]) ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1], -1)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
......
...@@ -298,28 +298,21 @@ class QWenLMHeadModel(nn.Module): ...@@ -298,28 +298,21 @@ class QWenLMHeadModel(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "attn.c_attn.weight",
"self_attn.o_proj.weight", "attn.c_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.c_proj.weight"
] ]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1]) ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
......
...@@ -150,8 +150,6 @@ class Qwen2Attention(nn.Module): ...@@ -150,8 +150,6 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -386,28 +384,21 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -386,28 +384,21 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.down_proj.weight"
] ]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1]) ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment