Commit c2bcb0ab authored by yangql
Browse files

增加moe awq-marlin的支持

parent cb37537e
......@@ -16,6 +16,10 @@ try:
from lmslim import quant_tools
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq of marlin.\n")
logger = init_logger(__name__)
......@@ -1473,7 +1477,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
output[e] = torch.ops.marlin.awq_marlin_repack(b_q_weight[e], size_k,
size_n, num_bits)
return output
......
......@@ -5,7 +5,11 @@ import functools
from typing import Optional
import torch
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq moe of marlin.\n")
import vllm.envs as envs
import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size, try_get_optimal_moe_config)
......@@ -14,28 +18,31 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool):
    """Select the quantized weight scalar type for a bit width.

    Args:
        num_bits: Weight bit width. ``4`` selects a 4-bit type; any other
            value falls through to the 8-bit type.
        has_zp: Whether the weights carry explicit zero-points. With
            zero-points the plain unsigned type is returned; without them
            the biased variant (``uint4b8`` / ``uint8b128``) is returned.

    Returns:
        The matching member of ``scalar_types``.
    """
    is_4bit = num_bits == 4
    if not has_zp:
        return scalar_types.uint4b8 if is_4bit else scalar_types.uint8b128
    return scalar_types.uint4 if is_4bit else scalar_types.uint8
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
def fused_marlin_moe(
hidden_states: torch.Tensor, # 32, 7168
w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2: torch.Tensor, # 256, 256, 7168
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
# workspace: Optional[torch.Tensor] = None,
num_bits: int = 4,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
"""
......@@ -65,16 +72,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
]
# quant_type = ScalarType.from_id(quant_type_id)
# assert quant_type in [
# scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
# scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
# ]
bit4_scalar_types = [
scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
]
num_bits = 4 if quant_type in bit4_scalar_types else 8
# bit4_scalar_types = [
# scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
# ]
# num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
......@@ -87,35 +94,48 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
topk = topk_ids.shape[1]
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
# assert num_bits in [4, 8]
# 目前只支持 uint4的量化结果
assert num_bits in [4]
M, K = hidden_states.shape # 32, 7168
E = w1.shape[0] # 256
N = w2.shape[1] * 16 # 256
topk = topk_ids.shape[1] # 8
# # 计算 topk_weights 和 topk_ids
# topk_weights, topk_ids = fused_topk(hidden_states, score, topk, False)
# 选择 block_size_m 的逻辑按照 Marlin来设置
for block_size_m in [16, 32, 48, 64, 80]:
if M * topk / E / block_size_m < 0.9:
break
# print("m: ", M, "; block_m: ", block_size_m)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, global_num_experts,
expert_map)
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
intermediate_cache2 = torch.empty(
# max_num = num_tokens_post_padded.item()
# print("max_num: ", max_num)
# 输出
# for i in range(0, max_num, block_size_m):
# print(i / block_size_m, sorted_token_ids[i:(i + block_size_m)])
# if workspace is None:
# max_workspace_size = (max(2 * N, K) // 64) * \
# (sorted_token_ids.size(0) // block_size_m)
# device = hidden_states.device
# sms = torch.cuda.get_device_properties(device).multi_processor_count
# max_workspace_size = min(max_workspace_size, sms * 4)
# workspace = torch.zeros(max_workspace_size,
# dtype=torch.int,
# device=device,
# requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
intermediate_cache2 = torch.empty( # [32*8, 256]
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
......@@ -125,94 +145,90 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] # [32*8, 512]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] # # [32*8, 7168]
intermediate_cache3 = intermediate_cache3.view(-1, K)
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
intermediate_cache1,
w1,
w1_scale,
global_scale1,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False)
intermediate_cache1.zero_()
intermediate_cache1 = torch.ops.marlin.moe_wna16_marlin_gemm(
hidden_states, # [32, 7168] # arg0: torch.Tensor,
intermediate_cache1, # [32*8, 512] # arg1: Optional[torch.Tensor]
w1, # arg2: torch.Tensor
w1_scale, # arg3: torch.Tensor
# w1_zeros, # arg4: Optional[torch.Tensor]
g_idx1, # arg5: Optional[torch.Tensor]
sort_indices1, # arg6: Optional[torch.Tensor]
# workspace, # arg7: torch.Tensor
sorted_token_ids, # arg8: torch.Tensor
expert_ids, # arg9: torch.Tensor
num_tokens_post_padded, # arg10: torch.Tensor
topk_weights, #arg11: torch.Tensor,
block_size_m,# arg12: int,
topk, # arg13: int,
False, # arg14: bool,
expert_map is not None, # arg15: bool,
scalar_type1.id, # arg16: int
M, # arg17: int,
2 * N, # arg18: int
K, # arg19: int,
is_k_full, # arg20: bool,
use_atomic_add, # arg21: bool,
True, # arg22: bool
False
) # arg23: bool
# [32*8, 512] --> [32*8, 256]
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
intermediate_cache3,
intermediate_cache3 = torch.ops.marlin.moe_wna16_marlin_gemm(
intermediate_cache2, # [32*8, 256]
intermediate_cache3, # [32*8, 7168]
w2,
w2_scale,
global_scale2,
w2_zeros,
# w2_zeros,
g_idx2,
sort_indices2,
workspace,
# workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False).view(-1, topk, K)
block_size_m,
1,
True,
expert_map is not None,
scalar_type2.id,
M * topk,
K,
N,
is_k_full,
use_atomic_add,
True,
False
).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=output)
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
# return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
# dim=1,
# out=output)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), output)
return output
def fused_marlin_moe_fake(
hidden_states: torch.Tensor, # 32, 7168
w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2: torch.Tensor, # 256, 256, 7168
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
......@@ -220,7 +236,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
# workspace: Optional[torch.Tensor] = None,
num_bits: int = 4,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states)
......
......@@ -338,7 +338,7 @@ class AWQLinearMethod(LinearMethodBase):
if envs.VLLM_USE_TRITON_AWQ:
if m>16:
m = 2 ** math.ceil(math.log2(m))
m = 1 << (m - 1).bit_length()
best_config=getspec_config(m,n,k)
out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config)
out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, ))
......
......@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_empty_g_idx, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
verify_marlin_supports_shape,
awq_marlin_moe_permute_sz)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
......@@ -131,10 +132,10 @@ class AWQMarlinConfig(QuantizationConfig):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix,
)
# logger.warning_once(
# "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
# prefix,
# )
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
......@@ -158,8 +159,8 @@ class AWQMarlinConfig(QuantizationConfig):
group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point")
if not current_platform.is_cuda():
return False
# if not current_platform.is_cuda():
# return False
if quant_method != "awq":
return False
......@@ -441,7 +442,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
#replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
......@@ -449,21 +450,41 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
#replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
layer.w13_qzeros,
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
# replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp = moe_awq_to_marlin_zero_points(
layer.w2_qzeros,
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
# replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
marlin_w13_sz = awq_marlin_moe_permute_sz(
marlin_w13_scales,
marlin_w13_zp,
size_k=layer.w13_scales.shape[1] * self.quant_config.group_size,
size_n=layer.w13_scales.shape[2]
)
marlin_w2_sz = awq_marlin_moe_permute_sz(
marlin_w2_scales,
marlin_w2_zp,
size_k=layer.w2_scales.shape[1] * self.quant_config.group_size,
size_n=layer.w2_scales.shape[2]
)
replace_parameter(layer, "w13_scales", marlin_w13_sz)
replace_parameter(layer, "w2_scales", marlin_w2_sz)
layer.w13_qzeros = None
layer.w2_qzeros = None
torch.cuda.empty_cache()
def apply(
self,
......@@ -482,6 +503,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
......@@ -503,7 +527,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
return torch.ops.vllm.fused_marlin_moe(
x,
......@@ -514,10 +540,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace)
# workspace=layer.workspace
num_bits=4
)
......@@ -14,6 +14,10 @@ from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq moe of marlin.\n")
logger = init_logger(__name__)
......@@ -153,7 +157,7 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return False, e.__str__()
return True, None
# Marlin linear layers are not supported yet.
def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
output_size_per_partition = getattr(layer, "output_size_per_partition",
......@@ -161,12 +165,12 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
input_size_per_partition = getattr(layer, "input_size_per_partition",
None) or layer.input_size
return check_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=layer.input_size,
group_size=group_size)[0]
# return check_marlin_supports_shape(
# output_size_per_partition=output_size_per_partition,
# input_size_per_partition=input_size_per_partition,
# input_size=layer.input_size,
# group_size=group_size)[0]
return False
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
......@@ -237,30 +241,46 @@ def marlin_sort_g_idx(
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def get_scale_perms():
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
# def get_scale_perms():
# scale_perm: list[int] = []
# for i in range(8):
# scale_perm.extend([i + 8 * j for j in range(8)])
# scale_perm_single: list[int] = []
# for i in range(4):
# scale_perm_single.extend(
# [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
# return scale_perm, scale_perm_single
# def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
# group_size: int) -> torch.Tensor:
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int) -> torch.Tensor:
# scale_perm, scale_perm_single = get_scale_perms()
# if group_size < size_k and group_size != -1:
# s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
# else:
# s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
# s = s.reshape((-1, size_n)).contiguous()
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
# return s
def get_scale_perms() -> list[int]:
    """Build the 128-entry column permutation applied to scales/zero-points.

    The result lists, for each offset ``i`` in ``0..15``, the eight indices
    ``i, i + 16, ..., i + 112``; i.e. it gathers the element at the same
    offset from eight consecutive groups of 16 columns.

    Returns:
        A permutation of ``range(128)`` as a flat list of ints.
    """
    # Outer loop over the offset inside a 16-wide group, inner loop over the
    # eight groups -- same ordering as the original extend-in-a-loop version.
    return [i + 16 * j for i in range(16) for j in range(8)]
def marlin_permute_scales(s: torch.Tensor, # [56, 512] # torch.float16
size_k: int, # 7168
size_n: int, # 512
group_size: int # 128
) -> torch.Tensor:
# 将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
scale_perm = get_scale_perms()
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
......@@ -281,19 +301,18 @@ def marlin_moe_permute_scales(
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
scale_perm, _ = get_scale_perms()
# 和 scale 使用一致的重排逻辑,将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
scale_perm = get_scale_perms()
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
# uint4 混排
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
# uint4打包成 int32
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
......@@ -474,3 +493,86 @@ def apply_awq_marlin_linear(
output.add_(bias) # In-place add
return output.reshape(out_shape)
def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor,
                       data_num_0: int, data_num_1: int) -> torch.Tensor:
    """Interleave fp16 scales and int32 zero-points block by block per row.

    Every output row alternates ``data_num_0`` float16 values taken from
    ``marlin_s`` with ``data_num_1`` int32 values taken from ``marlin_zp``;
    the concatenated bytes are then reinterpreted as int32.

    Args:
        marlin_s: 2-D float16 tensor of shape ``[N, D0]``; ``D0`` must be a
            multiple of ``data_num_0``.
        marlin_zp: 2-D int32 tensor of shape ``[N, D1]``; ``D1`` must be a
            multiple of ``data_num_1``.
        data_num_0: Number of float16 scales per block.
        data_num_1: Number of int32 zero-point words per block.

    Returns:
        An int32 tensor of shape ``[N, (2 * D0 + 4 * D1) // 4]``.

    Raises:
        AssertionError: If row counts, dtypes, divisibility, or block counts
            do not match.
        RuntimeError: Raised by torch if the merged row byte length is not a
            multiple of 4 and therefore cannot be viewed as int32.
    """
    assert marlin_s.shape[0] == marlin_zp.shape[0], "Batch size mismatch"
    assert marlin_s.dtype == torch.float16
    assert marlin_zp.dtype == torch.int32

    num_rows, s_cols = marlin_s.shape
    _, zp_cols = marlin_zp.shape

    assert s_cols % data_num_0 == 0, "marlin_s 每行必须能被 data_num_0 整除"
    assert zp_cols % data_num_1 == 0, "marlin_zp 每行必须能被 data_num_1 整除"

    s_block_count = s_cols // data_num_0
    zp_block_count = zp_cols // data_num_1
    assert s_block_count == zp_block_count
    total_blocks = s_block_count

    # Bytes contributed by one block of each operand
    # (float16 = 2 bytes, int32 = 4 bytes).
    s_chunk_bytes = data_num_0 * marlin_s.element_size()
    zp_chunk_bytes = data_num_1 * marlin_zp.element_size()

    # Byte views shaped [row, block, bytes-in-block]. ``contiguous()`` makes
    # the dtype reinterpretation valid for strided inputs as well.
    s_blocks = marlin_s.contiguous().view(torch.uint8).reshape(
        num_rows, total_blocks, s_chunk_bytes)
    zp_blocks = marlin_zp.contiguous().view(torch.uint8).reshape(
        num_rows, total_blocks, zp_chunk_bytes)

    # Concatenating along the byte axis places each scale block directly in
    # front of its zero-point block, which is exactly the alternating layout
    # the per-row Python loop used to build, but in a single tensor op.
    merged = torch.cat((s_blocks, zp_blocks), dim=2)
    row_bytes = total_blocks * (s_chunk_bytes + zp_chunk_bytes)
    return merged.reshape(num_rows, row_bytes).view(torch.int32)
def awq_marlin_moe_permute_sz(
    s: torch.Tensor,
    z: torch.Tensor,
    size_k: int,
    size_n: int,
) -> torch.Tensor:
    """Merge per-expert scales and zero-points into one packed tensor.

    For every expert index along dim 0, ``merge_scales_zeros`` is applied to
    that expert's scales and zero-points with fixed block sizes of 128 scale
    values and 16 zero-point words; the per-expert results are stacked.

    Args:
        s: Scales, indexed by expert along dim 0.
        z: Zero-points, indexed by expert along dim 0.
        size_k: Currently unused by this function.
        size_n: Currently unused by this function.

    Returns:
        The per-expert merged tensors stacked along a new leading dimension.
    """
    merged_per_expert = [
        merge_scales_zeros(s[expert], z[expert], 128, 16)
        for expert in range(s.shape[0])
    ]
    return torch.stack(merged_per_expert, dim=0)
......@@ -164,7 +164,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0'
# The AWQ layer of MTP uses BlockInt8W8A8.
if self.quant_method == "moe_wna16":
if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
......
......@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8"
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8","awq_marlin"
]
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment