Commit e128ad19 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-w4a8' into 'v0.9.2-dev'

fix: w4a8 marlin 中 weight重排接入lightop算子

See merge request dcutoolkit/deeplearing/vllm!202
parents f6324f60 082d41a1
......@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_2_marlin_weight
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
......@@ -22,6 +22,7 @@ try:
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
class MarlinMoeWorkspace:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
......@@ -205,16 +206,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w2_weight_scale.data, requires_grad=False
)
w1_marlin_list = []
for e in range(layer.w13_weight.shape[0]):
w1_marlin_in = w4a8_2_marlin_weight(layer.w13_weight[e])
w1_marlin_list.append(w1_marlin_in)
layer.w13_weight = Parameter(torch.stack(w1_marlin_list, dim=0), requires_grad=False)
w2_marlin_list = []
for e in range(layer.w2_weight.shape[0]):
w2_marlin_in = w4a8_2_marlin_weight(layer.w2_weight[e])
w2_marlin_list.append(w2_marlin_in)
layer.w2_weight = Parameter(torch.stack(w2_marlin_list, dim=0), requires_grad=False)
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def apply(
self,
......
......@@ -2,6 +2,12 @@
import torch
import numpy as np
try:
from lightop import awq_marlin_repack_w4a8
use_lightop = True
except Exception:
use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
......@@ -70,3 +76,18 @@ def w4a8_2_marlin_weight(w4a8_w):
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
return marlin_q_w
def w4a8_weight_repack_impl(input):
if use_lightop:
size_batch = input.shape[0]
size_n = input.shape[1]
size_k = input.shape[2] * 2
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
else:
w_marlin_list = []
for e in range(input.shape[0]):
w_marlin_in = w4a8_2_marlin_weight(input[e])
w_marlin_list.append(w_marlin_in)
output = torch.stack(w_marlin_list, dim=0)
return output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment