Commit 0b5e4e11 authored by zhuwenwen
Browse files

Add GEMM padding and FA padding for 7B models

parent 2d0a73a3
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
from vllm.model_executor.utils import gemm_bank_conf
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -88,6 +89,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -88,6 +89,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False): def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -114,6 +116,9 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -114,6 +116,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, weight) return F.linear(x, weight)
if self.use_llama_nn: if self.use_llama_nn:
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
weight = weight[:,:-32]
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
return torch.addmm(bias, x, weight) return torch.addmm(bias, x, weight)
......
...@@ -26,8 +26,15 @@ def get_model_architecture( ...@@ -26,8 +26,15 @@ def get_model_architecture(
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
......
...@@ -46,7 +46,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -46,7 +46,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -181,6 +183,8 @@ class BaiChuanAttention(nn.Module): ...@@ -181,6 +183,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -330,6 +334,8 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -330,6 +334,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -409,7 +415,13 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -409,7 +415,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -29,7 +29,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -29,7 +29,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
...@@ -104,6 +106,8 @@ class GLMAttention(nn.Module): ...@@ -104,6 +106,8 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn( context_layer = self.attn(
...@@ -357,6 +361,8 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -357,6 +361,8 @@ class ChatGLMForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -409,7 +415,13 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -409,7 +415,13 @@ class ChatGLMForCausalLM(nn.Module):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -52,6 +52,8 @@ from vllm.sequence import SamplerOutput ...@@ -52,6 +52,8 @@ from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -159,6 +161,8 @@ class LlamaAttention(nn.Module): ...@@ -159,6 +161,8 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -364,6 +368,8 @@ class LlamaForCausalLM(nn.Module): ...@@ -364,6 +368,8 @@ class LlamaForCausalLM(nn.Module):
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -452,7 +458,13 @@ class LlamaForCausalLM(nn.Module): ...@@ -452,7 +458,13 @@ class LlamaForCausalLM(nn.Module):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -33,6 +33,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -33,6 +33,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
def __init__( def __init__(
...@@ -120,6 +123,8 @@ class QWenAttention(nn.Module): ...@@ -120,6 +123,8 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -202,7 +207,6 @@ class QWenModel(nn.Module): ...@@ -202,7 +207,6 @@ class QWenModel(nn.Module):
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -242,6 +246,8 @@ class QWenLMHeadModel(nn.Module): ...@@ -242,6 +246,8 @@ class QWenLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -309,6 +315,12 @@ class QWenLMHeadModel(nn.Module): ...@@ -309,6 +315,12 @@ class QWenLMHeadModel(nn.Module):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -50,6 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -50,6 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
def __init__( def __init__(
...@@ -150,6 +153,8 @@ class Qwen2Attention(nn.Module): ...@@ -150,6 +153,8 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -322,6 +327,8 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -322,6 +327,8 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -395,6 +402,12 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -395,6 +402,12 @@ class Qwen2ForCausalLM(nn.Module):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -33,3 +33,32 @@ def set_weight_attrs( ...@@ -33,3 +33,32 @@ def set_weight_attrs(
assert not hasattr( assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}") weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value) setattr(weight, key, value)
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Return a contiguous copy of ``weight`` with zeros appended at the end.

    Args:
        weight: A 1-D or 2-D tensor. It is not modified.
        num_pad: Number of zero entries (1-D) or zero rows/columns (2-D)
            to append.
        pad_dim: For 2-D input, the dimension to grow: 0 appends rows,
            1 appends columns. Ignored for 1-D input.

    Returns:
        A new tensor with the same dtype and device as ``weight``.

    Raises:
        ValueError: If ``weight`` is not 1-D or 2-D, or if ``weight`` is
            2-D and ``pad_dim`` is neither 0 nor 1.
    """
    rank = weight.dim()
    if rank not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    if rank == 1:
        # Vectors can only grow along their single axis; pad_dim is unused.
        axis = 0
        fill_shape = (num_pad,)
    elif pad_dim == 0:
        axis = 0
        fill_shape = (num_pad, weight.shape[1])
    elif pad_dim == 1:
        axis = 1
        fill_shape = (weight.shape[0], num_pad)
    else:
        raise ValueError("pad_dim must be 0 or 1")

    # The zero block matches dtype/device so the concatenation needs no cast.
    zeros = torch.zeros(fill_shape, dtype=weight.dtype, device=weight.device)
    return torch.cat((weight, zeros), dim=axis).contiguous()
def gemm_bank_conf(weight):
    """Tell whether a size is a power of two that is also a multiple of 2048.

    Despite the parameter name, ``weight`` is treated as an integer (it is
    used with ``%`` and ``&``), not as a tensor. In practice the check is
    true exactly for 2048, 4096, 8192, ...
    NOTE(review): judging by the name, these are presumably the sizes that
    cause GEMM bank conflicts -- confirm against the kernel documentation.

    Args:
        weight: Integer size to test.

    Returns:
        True if ``weight`` is a non-zero power of two divisible by 2048,
        otherwise False.
    """
    divisible = (weight % 2048) == 0
    # A positive power of two has one set bit, so clearing the lowest set bit
    # (weight & (weight - 1)) leaves zero; zero itself is excluded.
    single_bit = weight != 0 and (weight & (weight - 1)) == 0
    return bool(divisible and single_bit)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment