Commit c2bcb0ab authored by yangql
Browse files

增加moe awq-marlin的支持

parent cb37537e
...@@ -16,6 +16,10 @@ try: ...@@ -16,6 +16,10 @@ try:
from lmslim import quant_tools from lmslim import quant_tools
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq of marlin.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1473,7 +1477,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, ...@@ -1473,7 +1477,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
device=b_q_weight.device, device=b_q_weight.device,
dtype=b_q_weight.dtype) dtype=b_q_weight.dtype)
for e in range(num_experts): for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, output[e] = torch.ops.marlin.awq_marlin_repack(b_q_weight[e], size_k,
size_n, num_bits) size_n, num_bits)
return output return output
......
...@@ -5,7 +5,11 @@ import functools ...@@ -5,7 +5,11 @@ import functools
from typing import Optional from typing import Optional
import torch import torch
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq moe of marlin.\n")
import vllm.envs as envs
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size, try_get_optimal_moe_config) moe_align_block_size, try_get_optimal_moe_config)
...@@ -14,28 +18,31 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -14,28 +18,31 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool):
    """Pick the ``scalar_types`` entry for a quantized weight.

    Args:
        num_bits: Weight bit width. Only the value ``4`` is distinguished;
            every other value selects the 8-bit variant.
        has_zp: Whether the weights come with zero-points.

    Returns:
        ``scalar_types.uint4`` / ``scalar_types.uint8`` when ``has_zp`` is
        truthy, otherwise ``scalar_types.uint4b8`` /
        ``scalar_types.uint8b128``.
    """
    is_four_bit = num_bits == 4
    if not has_zp:
        # No zero-points: use the "b8" / "b128" variants.
        if is_four_bit:
            return scalar_types.uint4b8
        return scalar_types.uint8b128
    if is_four_bit:
        return scalar_types.uint4
    return scalar_types.uint8
def fused_marlin_moe(hidden_states: torch.Tensor, def fused_marlin_moe(
w1: torch.Tensor, hidden_states: torch.Tensor, # 32, 7168
w2: torch.Tensor, w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2: torch.Tensor, # 256, 256, 7168
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None, sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None, # workspace: Optional[torch.Tensor] = None,
num_bits: int = 4,
is_k_full: bool = True, is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor: inplace: bool = False) -> torch.Tensor:
""" """
...@@ -65,16 +72,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -65,16 +72,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
quant_type = ScalarType.from_id(quant_type_id) # quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [ # assert quant_type in [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, # scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f # scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
] # ]
bit4_scalar_types = [ # bit4_scalar_types = [
scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f # scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
] # ]
num_bits = 4 if quant_type in bit4_scalar_types else 8 # num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints. # Check constraints.
assert hidden_states.shape[0] == gating_output.shape[ assert hidden_states.shape[0] == gating_output.shape[
...@@ -87,35 +94,48 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -87,35 +94,48 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8] # assert num_bits in [4, 8]
# 目前只支持 uint4的量化结果
M, K = hidden_states.shape assert num_bits in [4]
E = w1.shape[0]
N = w2.shape[1] * 16 M, K = hidden_states.shape # 32, 7168
topk = topk_ids.shape[1] E = w1.shape[0] # 256
N = w2.shape[1] * 16 # 256
get_config_func = functools.partial( topk = topk_ids.shape[1] # 8
try_get_optimal_moe_config, # # 计算 topk_weights 和 topk_ids
w1.shape, # topk_weights, topk_ids = fused_topk(hidden_states, score, topk, False)
w2.shape,
topk_ids.shape[1], # 选择 block_size_m 的逻辑按照 Marlin来设置
None, for block_size_m in [16, 32, 48, 64, 80]:
is_marlin=True, if M * topk / E / block_size_m < 0.9:
) break
config = get_config_func(M) # print("m: ", M, "; block_m: ", block_size_m)
block_size_m = config["BLOCK_SIZE_M"]
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \ sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, global_num_experts, moe_align_block_size(topk_ids, block_size_m, global_num_experts,
expert_map) expert_map)
# max_num = num_tokens_post_padded.item()
if workspace is None: # print("max_num: ", max_num)
workspace = marlin_make_workspace_new(hidden_states.device, 4) # 输出
# for i in range(0, max_num, block_size_m):
intermediate_cache2 = torch.empty( # print(i / block_size_m, sorted_token_ids[i:(i + block_size_m)])
# if workspace is None:
# max_workspace_size = (max(2 * N, K) // 64) * \
# (sorted_token_ids.size(0) // block_size_m)
# device = hidden_states.device
# sms = torch.cuda.get_device_properties(device).multi_processor_count
# max_workspace_size = min(max_workspace_size, sms * 4)
# workspace = torch.zeros(max_workspace_size,
# dtype=torch.int,
# device=device,
# requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
intermediate_cache2 = torch.empty( # [32*8, 256]
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
...@@ -125,104 +145,101 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -125,104 +145,101 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] # [32*8, 512]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] # # [32*8, 7168]
intermediate_cache3 = intermediate_cache3.view(-1, K) intermediate_cache3 = intermediate_cache3.view(-1, K)
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = hidden_states.dtype == torch.half or \ use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
intermediate_cache1 = ops.moe_wna16_marlin_gemm( intermediate_cache1.zero_()
hidden_states, intermediate_cache1 = torch.ops.marlin.moe_wna16_marlin_gemm(
intermediate_cache1, hidden_states, # [32, 7168] # arg0: torch.Tensor,
w1, intermediate_cache1, # [32*8, 512] # arg1: Optional[torch.Tensor]
w1_scale, w1, # arg2: torch.Tensor
global_scale1, w1_scale, # arg3: torch.Tensor
w1_zeros, # w1_zeros, # arg4: Optional[torch.Tensor]
g_idx1, g_idx1, # arg5: Optional[torch.Tensor]
sort_indices1, sort_indices1, # arg6: Optional[torch.Tensor]
workspace, # workspace, # arg7: torch.Tensor
sorted_token_ids, sorted_token_ids, # arg8: torch.Tensor
expert_ids, expert_ids, # arg9: torch.Tensor
num_tokens_post_padded, num_tokens_post_padded, # arg10: torch.Tensor
topk_weights, topk_weights, #arg11: torch.Tensor,
moe_block_size=block_size_m, block_size_m,# arg12: int,
top_k=topk, topk, # arg13: int,
mul_topk_weights=apply_router_weight_on_input, False, # arg14: bool,
is_ep=expert_map is not None, expert_map is not None, # arg15: bool,
b_q_type=quant_type, scalar_type1.id, # arg16: int
size_m=M, M, # arg17: int,
size_n=2 * N, 2 * N, # arg18: int
size_k=K, K, # arg19: int,
is_k_full=is_k_full, is_k_full, # arg20: bool,
use_atomic_add=use_atomic_add, use_atomic_add, # arg21: bool,
use_fp32_reduce=True, True, # arg22: bool
is_zp_float=False) False
) # arg23: bool
# [32*8, 512] --> [32*8, 256]
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N)) intermediate_cache1.view(-1, 2 * N))
intermediate_cache3.zero_()
if expert_map is not None: intermediate_cache3 = torch.ops.marlin.moe_wna16_marlin_gemm(
intermediate_cache3.zero_() intermediate_cache2, # [32*8, 256]
intermediate_cache3, # [32*8, 7168]
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
intermediate_cache3,
w2, w2,
w2_scale, w2_scale,
global_scale2, # w2_zeros,
w2_zeros,
g_idx2, g_idx2,
sort_indices2, sort_indices2,
workspace, # workspace,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
topk_weights, topk_weights,
moe_block_size=block_size_m, block_size_m,
top_k=1, 1,
mul_topk_weights=not apply_router_weight_on_input, True,
is_ep=expert_map is not None, expert_map is not None,
b_q_type=quant_type, scalar_type2.id,
size_m=M * topk, M * topk,
size_n=K, K,
size_k=N, N,
is_k_full=is_k_full, is_k_full,
use_atomic_add=use_atomic_add, use_atomic_add,
use_fp32_reduce=True, True,
is_zp_float=False).view(-1, topk, K) False
).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states) output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), # return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1, # dim=1,
out=output) # out=output)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), output)
return output
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor, def fused_marlin_moe_fake(
w2: torch.Tensor, hidden_states: torch.Tensor, # 32, 7168
w1_scale: torch.Tensor, w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2_scale: torch.Tensor, w2: torch.Tensor, # 256, 256, 7168
gating_output: torch.Tensor, w1_scale: torch.Tensor,
topk_weights: torch.Tensor, w2_scale: torch.Tensor,
topk_ids: torch.Tensor, gating_output: torch.Tensor,
quant_type_id: int, topk_weights: torch.Tensor,
apply_router_weight_on_input: bool = False, topk_ids: torch.Tensor,
global_num_experts: int = -1, global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None, sort_indices2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None, # workspace: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None, num_bits: int = 4,
workspace: Optional[torch.Tensor] = None, is_k_full: bool = True,
is_k_full: bool = True, inplace: bool = False) -> torch.Tensor:
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -338,7 +338,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -338,7 +338,7 @@ class AWQLinearMethod(LinearMethodBase):
if envs.VLLM_USE_TRITON_AWQ: if envs.VLLM_USE_TRITON_AWQ:
if m>16: if m>16:
m = 2 ** math.ceil(math.log2(m)) m = 1 << (m - 1).bit_length()
best_config=getspec_config(m,n,k) best_config=getspec_config(m,n,k)
out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config) out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config)
out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, )) out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, ))
......
...@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_make_empty_g_idx, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_permute_scales, marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported, moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape,
awq_marlin_moe_permute_sz)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter)
...@@ -131,10 +132,10 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -131,10 +132,10 @@ class AWQMarlinConfig(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin. # Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size): if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once( # logger.warning_once(
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501 # "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix, # prefix,
) # )
return AWQConfig.from_config( return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self) return AWQMarlinLinearMethod(self)
...@@ -158,8 +159,8 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -158,8 +159,8 @@ class AWQMarlinConfig(QuantizationConfig):
group_size = quant_config.get("group_size") group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point") zero_point = quant_config.get("zero_point")
if not current_platform.is_cuda(): # if not current_platform.is_cuda():
return False # return False
if quant_method != "awq": if quant_method != "awq":
return False return False
...@@ -441,7 +442,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -441,7 +442,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
) )
replace_parameter(layer, "w13_scales", marlin_w13_scales) #replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales, s=layer.w2_scales,
...@@ -449,21 +450,41 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -449,21 +450,41 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
) )
replace_parameter(layer, "w2_scales", marlin_w2_scales) #replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points( marlin_w13_zp = moe_awq_to_marlin_zero_points(
layer.w13_qzeros, layer.w13_qzeros,
size_k=layer.w13_qzeros.shape[1], size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits) num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp) # replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp = moe_awq_to_marlin_zero_points( marlin_w2_zp = moe_awq_to_marlin_zero_points(
layer.w2_qzeros, layer.w2_qzeros,
size_k=layer.w2_qzeros.shape[1], size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits) num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp) # replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
marlin_w13_sz = awq_marlin_moe_permute_sz(
marlin_w13_scales,
marlin_w13_zp,
size_k=layer.w13_scales.shape[1] * self.quant_config.group_size,
size_n=layer.w13_scales.shape[2]
)
marlin_w2_sz = awq_marlin_moe_permute_sz(
marlin_w2_scales,
marlin_w2_zp,
size_k=layer.w2_scales.shape[1] * self.quant_config.group_size,
size_n=layer.w2_scales.shape[2]
)
replace_parameter(layer, "w13_scales", marlin_w13_sz)
replace_parameter(layer, "w2_scales", marlin_w2_sz)
layer.w13_qzeros = None
layer.w2_qzeros = None
torch.cuda.empty_cache()
def apply( def apply(
self, self,
...@@ -482,6 +503,9 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -482,6 +503,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
enable_eplb: bool = False, enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
...@@ -503,7 +527,9 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -503,7 +527,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
...@@ -514,10 +540,12 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -514,10 +540,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_type.id, # quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, # apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace) # workspace=layer.workspace
num_bits=4
)
...@@ -14,7 +14,11 @@ from vllm.platforms import current_platform ...@@ -14,7 +14,11 @@ from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols from .quant_utils import pack_cols, unpack_cols
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq moe of marlin.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_TILE = 16
...@@ -153,7 +157,7 @@ def check_marlin_supports_shape(output_size_per_partition: int, ...@@ -153,7 +157,7 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return False, e.__str__() return False, e.__str__()
return True, None return True, None
#暂不支持marlinlinear
def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool: -> bool:
output_size_per_partition = getattr(layer, "output_size_per_partition", output_size_per_partition = getattr(layer, "output_size_per_partition",
...@@ -161,12 +165,12 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ ...@@ -161,12 +165,12 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
input_size_per_partition = getattr(layer, "input_size_per_partition", input_size_per_partition = getattr(layer, "input_size_per_partition",
None) or layer.input_size None) or layer.input_size
return check_marlin_supports_shape( # return check_marlin_supports_shape(
output_size_per_partition=output_size_per_partition, # output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition, # input_size_per_partition=input_size_per_partition,
input_size=layer.input_size, # input_size=layer.input_size,
group_size=group_size)[0] # group_size=group_size)[0]
return False
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool: -> bool:
...@@ -237,30 +241,46 @@ def marlin_sort_g_idx( ...@@ -237,30 +241,46 @@ def marlin_sort_g_idx(
return g_idx[g_idx_sort_indices], g_idx_sort_indices return g_idx[g_idx_sort_indices], g_idx_sort_indices
# def get_scale_perms():
# scale_perm: list[int] = []
# for i in range(8):
# scale_perm.extend([i + 8 * j for j in range(8)])
# scale_perm_single: list[int] = []
# for i in range(4):
# scale_perm_single.extend(
# [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
# return scale_perm, scale_perm_single
# def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
# group_size: int) -> torch.Tensor:
# scale_perm, scale_perm_single = get_scale_perms()
# if group_size < size_k and group_size != -1:
# s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
# else:
# s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
# s = s.reshape((-1, size_n)).contiguous()
# return s
def get_scale_perms(): def get_scale_perms():
scale_perm: list[int] = [] scale_perm: List[int] = []
for i in range(8): for i in range(16): # 遍历列方向不同scale的 8个线程
scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm.extend([i + 16 * j for j in range(8)]) # 插入 8 个数据块中 对应位置的索引
scale_perm_single: list[int] = [] return scale_perm
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) def marlin_permute_scales(s: torch.Tensor, # [56, 512] # torch.float16
return scale_perm, scale_perm_single size_k: int, # 7168
size_n: int, # 512
group_size: int # 128
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, ) -> torch.Tensor:
group_size: int) -> torch.Tensor: # 将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
scale_perm = get_scale_perms()
scale_perm, scale_perm_single = get_scale_perms() s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous() s = s.reshape((-1, size_n)).contiguous()
return s return s
def marlin_moe_permute_scales( def marlin_moe_permute_scales(
s: torch.Tensor, s: torch.Tensor,
size_k: int, size_k: int,
...@@ -281,19 +301,18 @@ def marlin_moe_permute_scales( ...@@ -281,19 +301,18 @@ def marlin_moe_permute_scales(
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor: num_bits: int) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the # 和 scale 使用一致的重排逻辑,将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
# "single" permutation, since zero-points are applied on every MMA scale_perm = get_scale_perms()
scale_perm, _ = get_scale_perms()
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32 # uint4 混排
if num_bits == 4: if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8: elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3]) interleave = numpy.array([0, 2, 1, 3])
else: else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
# uint4打包成 int32
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous() zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n) zp = pack_cols(zp, num_bits, size_k, size_n)
...@@ -474,3 +493,86 @@ def apply_awq_marlin_linear( ...@@ -474,3 +493,86 @@ def apply_awq_marlin_linear(
output.add_(bias) # In-place add output.add_(bias) # In-place add
return output.reshape(out_shape) return output.reshape(out_shape)
def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor,
                       data_num_0: int, data_num_1: int) -> torch.Tensor:
    """Interleave fp16 scales and int32 zero-points, row by row, as int32.

    Every output row is a sequence of blocks. Block ``b`` holds the raw
    bytes of ``data_num_0`` float16 values taken from ``marlin_s`` followed
    by the raw bytes of ``data_num_1`` int32 values taken from
    ``marlin_zp``. The bytes are reinterpreted, not numerically converted.

    Args:
        marlin_s: 2-D float16 tensor ``[N, D0]``; ``D0`` must be divisible
            by ``data_num_0``.
        marlin_zp: 2-D int32 tensor ``[N, D1]``; ``D1`` must be divisible
            by ``data_num_1`` and yield the same block count as
            ``marlin_s``.
        data_num_0: Number of float16 values per block.
        data_num_1: Number of int32 values per block.

    Returns:
        int32 tensor of shape
        ``[N, blocks * (2 * data_num_0 + 4 * data_num_1) // 4]``.
        Inputs with zero rows or zero blocks produce an empty result
        instead of failing.

    Raises:
        AssertionError: On mismatched row counts, wrong dtypes, or row
            lengths that do not split into an equal number of blocks.
        RuntimeError: Raised by torch when the merged bytes of one row are
            not a multiple of 4 and so cannot be viewed as int32.
    """
    assert marlin_s.shape[0] == marlin_zp.shape[0], "Batch size mismatch"
    assert marlin_s.dtype == torch.float16
    assert marlin_zp.dtype == torch.int32

    N, D0 = marlin_s.shape
    _, D1 = marlin_zp.shape

    assert D0 % data_num_0 == 0, "marlin_s 每行必须能被 data_num_0 整除"
    assert D1 % data_num_1 == 0, "marlin_zp 每行必须能被 data_num_1 整除"

    total_blocks = D0 // data_num_0
    assert total_blocks == D1 // data_num_1

    s_block_bytes = data_num_0 * 2  # float16 = 2 bytes
    zp_block_bytes = data_num_1 * 4  # int32 = 4 bytes

    # Byte views split into [row, block, bytes-in-block]; one concatenation
    # along the last axis then places each scale block directly in front of
    # its zero-point block, replacing the per-row / per-block Python loop.
    s_bytes = marlin_s.contiguous().view(torch.uint8).reshape(
        N, total_blocks, s_block_bytes)
    zp_bytes = marlin_zp.contiguous().view(torch.uint8).reshape(
        N, total_blocks, zp_block_bytes)
    merged = torch.cat((s_bytes, zp_bytes), dim=2)

    # Use an explicit row width: reshape(N, -1) is ambiguous when N == 0.
    row_bytes = total_blocks * (s_block_bytes + zp_block_bytes)
    return merged.reshape(N, row_bytes).view(torch.int32)
def awq_marlin_moe_permute_sz(
    s: torch.Tensor,
    z: torch.Tensor,
    size_k: int,
    size_n: int,
) -> torch.Tensor:
    """Merge per-expert scales and zero-points into one stacked tensor.

    For each index ``e`` along the first dimension of ``s``, ``s[e]`` and
    ``z[e]`` are combined with ``merge_scales_zeros`` using fixed block
    sizes of 128 scale values and 16 zero-point values; the per-expert
    results are stacked along a new leading dimension.

    Args:
        s: Scales; its first dimension gives the number of experts.
        z: Zero-points, indexed with the same expert indices as ``s``.
        size_k: Currently unused.
        size_n: Currently unused.

    Returns:
        The stacked per-expert outputs of ``merge_scales_zeros``.
    """
    scales_per_block = 128
    zeros_per_block = 16
    merged_per_expert = [
        merge_scales_zeros(s[idx], z[idx], scales_per_block, zeros_per_block)
        for idx in range(s.shape[0])
    ]
    return torch.stack(merged_per_expert, dim=0)
...@@ -164,7 +164,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -164,7 +164,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
# The AWQ layer of MTP uses BlockInt8W8A8. # The AWQ layer of MTP uses BlockInt8W8A8.
if self.quant_method == "moe_wna16": if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128]) vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
......
...@@ -180,7 +180,7 @@ class RocmPlatform(Platform): ...@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8" "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8","awq_marlin"
] ]
@classmethod @classmethod
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment