Commit 1e77d04e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1' into v0.5.2-dtk24.04.1

# Conflicts:
#	csrc/attention/attention_kernels.cu
#	csrc/attention/attention_utils.cuh
#	csrc/layernorm_kernels.cu
#	vllm/model_executor/layers/linear.py
#	vllm/model_executor/models/baichuan.py
#	vllm/model_executor/models/llama.py
parents 6fa22430 c62f8e9a
......@@ -27,6 +27,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -53,6 +54,7 @@ from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops
class LlamaMLP(nn.Module):
......@@ -160,8 +162,6 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -383,6 +383,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -482,8 +483,28 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
......
......@@ -10,6 +10,9 @@ import torch
from torch import nn
from transformers import PretrainedConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -30,7 +33,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from vllm import _custom_ops as ops
class QWenMLP(nn.Module):
def __init__(
......@@ -200,6 +203,7 @@ class QWenModel(nn.Module):
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -240,6 +244,7 @@ class QWenLMHeadModel(nn.Module):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -305,3 +310,24 @@ class QWenLMHeadModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"attn.c_attn.weight",
"attn.c_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.c_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
......@@ -28,6 +28,7 @@ import torch
from torch import nn
from transformers import Qwen2Config
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -51,7 +52,7 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops
class Qwen2MLP(nn.Module):
def __init__(
......@@ -152,8 +153,6 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -327,6 +326,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -401,3 +401,25 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment