Commit 4f6c0cd4 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix: w4a8 marlin 中 weight重排接入lightop算子

parent 2b8d795b
...@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods ...@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase) QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_2_marlin_weight from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
...@@ -205,16 +205,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -205,16 +205,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w2_weight_scale.data, requires_grad=False layer.w2_weight_scale.data, requires_grad=False
) )
w1_marlin_list = [] layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
for e in range(layer.w13_weight.shape[0]): layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
w1_marlin_in = w4a8_2_marlin_weight(layer.w13_weight[e])
w1_marlin_list.append(w1_marlin_in)
layer.w13_weight = Parameter(torch.stack(w1_marlin_list, dim=0), requires_grad=False)
w2_marlin_list = []
for e in range(layer.w2_weight.shape[0]):
w2_marlin_in = w4a8_2_marlin_weight(layer.w2_weight[e])
w2_marlin_list.append(w2_marlin_in)
layer.w2_weight = Parameter(torch.stack(w2_marlin_list, dim=0), requires_grad=False)
def apply( def apply(
self, self,
......
import torch import torch
import numpy as np import numpy as np
try:
from lightop import awq_marlin_repack_w4a8
use_lightop = True
except Exception:
use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor: def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
""" """
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。 将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
...@@ -69,3 +75,19 @@ def w4a8_2_marlin_weight(w4a8_w): ...@@ -69,3 +75,19 @@ def w4a8_2_marlin_weight(w4a8_w):
weight_perm = get_weight_perms() weight_perm = get_weight_perms()
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8) marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
return marlin_q_w return marlin_q_w
def w4a8_weight_repack_impl(input):
if use_lightop:
size_batch = input.shape[0]
size_n = input.shape[1]
size_k = input.shape[2] * 2
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
else:
w_marlin_list = []
for e in range(input.shape[0]):
w_marlin_in = w4a8_2_marlin_weight(input[e])
w_marlin_list.append(w_marlin_in)
output = torch.stack(w_marlin_list, dim=0)
return output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment