Commit d0181e5a authored by jujl1's avatar jujl1
Browse files

feat: w4a8 marlin 初版合入

parent 20316346
......@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_2_marlin_weight
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (BasevLLMParameter,
......@@ -18,10 +19,29 @@ from lmslim.layers.gemm.int8_utils import (
per_token_quant_int8)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
import vllm.envs as envs
import os
from vllm import _custom_ops as ops
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
workspace = None
global_reduce_buffer = None


def get_marlin_moe_workspace(device):
    """Return the shared Marlin MoE scratch buffers, allocating them on first use.

    Both buffers are zero-initialised ``torch.int`` tensors kept in the module
    globals ``workspace`` and ``global_reduce_buffer``; once a global is set,
    later calls return the very same tensor object.

    Args:
        device: Device the buffers are allocated on when they do not exist
            yet. It is also passed to ``torch.cuda.get_device_properties``
            to size the reduce buffer.

    Returns:
        Tuple ``(workspace, global_reduce_buffer)``: a 500-element tensor and
        a tensor of ``multi_processor_count * 6 * 128 * 512`` elements.
    """
    # NOTE(review): the cache is not keyed by device -- a later call with a
    # different device still receives the buffers created for the first one.
    global workspace
    global global_reduce_buffer
    if workspace is None:
        # Fixed-size buffer: no device property is needed to allocate it.
        workspace = torch.zeros(500, dtype=torch.int, device=device, requires_grad=False)
    if global_reduce_buffer is None:
        # The reduce buffer scales with the device's multiprocessor count.
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        global_reduce_buffer = torch.zeros(
            sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
        )
    return workspace, global_reduce_buffer
# Module-level W8a8GetCacheJSON instance, constructed once when this module is imported.
# NOTE(review): presumably the shared holder of cached Triton tuning configs -- confirm
# against vllm.utils.W8a8GetCacheJSON.
W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor,
......@@ -319,14 +339,27 @@ class SlimQuantW4A8Int8MoEMethod:
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
# layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
# layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
w1_marlin_list = []
for e in range(layer.w13_weight.shape[0]):
w1_marlin_in = w4a8_2_marlin_weight(layer.w13_weight[e])
w1_marlin_list.append(w1_marlin_in)
layer.w13_weight = Parameter(torch.stack(w1_marlin_list, dim=0), requires_grad=False)
w2_marlin_list = []
for e in range(layer.w2_weight.shape[0]):
w2_marlin_in = w4a8_2_marlin_weight(layer.w2_weight[e])
w2_marlin_list.append(w2_marlin_in)
layer.w2_weight = Parameter(torch.stack(w2_marlin_list, dim=0), requires_grad=False)
def apply(
self,
......@@ -370,13 +403,15 @@ class SlimQuantW4A8Int8MoEMethod:
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts(
workspace, global_reduce_buffer=get_marlin_moe_workspace(device=x.device)
return fused_experts_impl_w4a8_marlin(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
......
import torch
import numpy as np
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """Expand a packed ``[N, K//2]`` int8 tensor into a ``[N, K]`` int32 tensor.

    Every int8 byte carries two 4-bit values. The low nibble lands in the even
    output column and the high nibble in the following odd column; each output
    element holds its nibble in the low 4 bits with all other bits zero.

    Args:
        tensor_int8: 2-D tensor of dtype ``torch.int8`` with shape ``[N, K//2]``.

    Returns:
        ``torch.int32`` tensor of shape ``[N, K]`` on the same device as the input.

    Raises:
        ValueError: If the input dtype is not ``torch.int8``.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")

    rows, packed_cols = tensor_int8.shape
    # Work on unsigned bytes so the right shift never drags in sign bits.
    as_bytes = tensor_int8.to(torch.uint8)
    # The trailing axis of size 2 is (low nibble, high nibble); flattening it
    # into the column axis yields the low/high interleaving.
    nibble_pairs = torch.stack((as_bytes & 0x0F, (as_bytes >> 4) & 0x0F), dim=-1)
    return nibble_pairs.to(torch.int32).reshape(rows, packed_cols * 2)
def get_weight_perms(interleave: bool=True):
    """Build the 2048-entry index permutation applied to each weight tile.

    Indices address a 32 x 64 grid flattened as ``row * 64 + col``. The order
    walks 64 groups; each group covers 4 consecutive columns, and for every
    column emits 8 consecutive rows.

    Args:
        interleave: When true, every run of 8 indices is additionally
            reordered as ``[4, 0, 5, 1, 6, 2, 7, 3]``.

    Returns:
        1-D integer ``torch.Tensor`` holding the permutation.
    """
    # Broadcast (group, column-in-group, row-in-group) instead of looping.
    group = np.arange(64).reshape(64, 1, 1)
    col_in_group = np.arange(4).reshape(1, 4, 1)
    row_in_group = np.arange(8).reshape(1, 1, 8)

    src_row = (group // 16) * 8 + row_in_group
    src_col = (group % 16) * 4 + col_in_group
    flat = (src_row * 64 + src_col).ravel()

    if interleave:
        lane_order = np.array([4, 0, 5, 1, 6, 2, 7, 3])
        flat = flat.reshape((-1, 8))[:, lane_order].ravel()
    return torch.from_numpy(flat)
def marlin_weights(q_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8):
    """Tile, permute and pack a ``[size_k, size_n]`` tensor of 4-bit values.

    The matrix is cut into ``k_tile x n_tile`` tiles, the tiles of each k-block
    are laid out contiguously, elements are reordered by ``weight_perm`` in
    chunks of ``weight_perm.numel()``, and finally every ``pack_factor``
    consecutive values are packed into one 32-bit word, 4 bits per value with
    the first value in the lowest bits.

    Args:
        q_w: 2-D tensor of quantized values; ``size_k`` and ``size_n`` are
            expected to be multiples of ``k_tile`` and ``n_tile``.
        weight_perm: 1-D index tensor used to reorder each chunk.
        k_tile: Tile height along the first dimension.
        n_tile: Tile width along the second dimension.
        pack_factor: Number of values packed into each output word.

    Returns:
        ``torch.int32`` tensor of shape
        ``[size_k // k_tile, size_n * k_tile // pack_factor]`` on the device
        of ``q_w``.
    """
    size_k, size_n = q_w.shape
    k_blocks = size_k // k_tile
    n_blocks = size_n // n_tile

    # Bring each tile's rows next to each other, then flatten per k-block.
    tiled = q_w.reshape((k_blocks, k_tile, n_blocks, n_tile)).permute((0, 2, 1, 3))
    tiled = tiled.reshape((k_blocks, size_n * k_tile))

    chunks = tiled.reshape((-1, weight_perm.numel()))
    permuted = chunks[:, weight_perm].reshape(tiled.shape)

    # Packing runs in NumPy on the host with unsigned words; the result is
    # moved back to the source device afterwards.
    target_device = permuted.device
    values = permuted.cpu().numpy().astype(np.uint32)
    packed = np.zeros((values.shape[0], values.shape[1] // pack_factor), dtype=np.uint32)
    for slot in range(pack_factor):
        packed |= values[:, slot::pack_factor] << 4 * slot

    return torch.from_numpy(packed.astype(np.int32)).to(target_device)
def w4a8_2_marlin_weight(w4a8_w):
    """Convert one packed W4A8 weight matrix into the Marlin packed layout.

    The int8-packed input is unpacked to one 4-bit value per element,
    transposed, and then tiled/permuted/packed by ``marlin_weights`` using the
    interleaved permutation from ``get_weight_perms``.

    Args:
        w4a8_w: 2-D ``torch.int8`` tensor holding two 4-bit values per byte.

    Returns:
        ``torch.int32`` tensor in the layout produced by ``marlin_weights``.
    """
    # marlin_weights takes the transpose of the unpacked matrix.
    unpacked_t = unpack_int8_to_int4(w4a8_w).T
    return marlin_weights(
        unpacked_t,
        get_weight_perms(),
        k_tile=32,
        n_tile=64,
        pack_factor=8,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment