Commit ca9ce18d authored by flyingdown, committed by zhangzbb
Browse files

[FEATURE]: Support W4A16 AITER MoE

parent ba6f2101
...@@ -72,6 +72,23 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_ ...@@ -72,6 +72,23 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
from lightop import fuse_silu_and_mul from lightop import fuse_silu_and_mul
from lightop import op as op from lightop import op as op
# Optional AITER fused-MoE backend.
#
# The aiter symbols are imported only when VLLM_ROCM_USE_AITER_MOE is set.
# On any failure (flag unset, aiter missing, or aiter failing to import) every
# one of the four names is bound to None, so later code can gate on
# `... is not None` without risking a NameError.
try:
    if envs.VLLM_ROCM_USE_AITER_MOE:
        from aiter.moe import (
            get_aiter_moe_config,
            aiter_moe,
            MoeSolutionType,
            MoeQuantType,
        )
    else:
        # Reuse the fallback path below when the feature flag is off.
        raise Exception("VLLM_ROCM_USE_AITER_MOE not set.")
except Exception:
    # Deliberately broad: the aiter fast path is best-effort and must never
    # prevent this module from importing.
    get_aiter_moe_config = None
    aiter_moe = None
    # Previously left unbound in this branch, unlike its three siblings.
    MoeSolutionType = None
    MoeQuantType = None
    # NOTE: this notice is printed on every fallback, including when the flag
    # is simply unset. `logger` is only created after this block, hence print.
    print("INFO: Please install aiter if you want to infer with aiter_moe.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
...@@ -1744,6 +1761,32 @@ def fused_experts_impl( ...@@ -1744,6 +1761,32 @@ def fused_experts_impl(
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_ROCM_USE_AITER_MOE and use_int4_w4a16 and hidden_states.dtype == torch.bfloat16 and get_aiter_moe_config is not None and aiter_moe is not None:
# Decide whether to enable aiter based on aiter's own config lookup
M, K = hidden_states.shape
E, N1, _ = w1.shape
_, N2, _ = w2.shape
top_k_num = topk_ids.size(1)
status, moe_config = get_aiter_moe_config(
M=M, E=E, N1=N1, N2=N2, K=K,
top_k=top_k_num, block_size=block_shape[1], dtype=hidden_states.dtype,
quant_type=MoeQuantType.W4A16,
)
if not status:
logger.info_once(
f"[aiter_moe_w4a16] SKIP {M=}, {E=}, {N1=}, {N2=}, {K=}, {top_k_num=}, {block_shape=}: "
f"no backend available"
)
else:
is_inplace = inplace and not disable_inplace()
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, is_inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, global_num_experts, expert_map)
# return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
# block_shape, global_num_experts, expert_map, activation)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the # Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout. # expert weights are already packed in Marlin layout.
if not use_nn_moe: if not use_nn_moe:
......
...@@ -12,6 +12,7 @@ from compressed_tensors.quantization import ( ...@@ -12,6 +12,7 @@ from compressed_tensors.quantization import (
QuantizationStrategy, QuantizationStrategy,
) )
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -1806,6 +1807,36 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1806,6 +1807,36 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.a13_scale = None layer.a13_scale = None
layer.a2_scale = None layer.a2_scale = None
if envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using aiter moe")
w1_zp = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // self.group_size // 2 if self.num_bits == 4 else hidden_size // self.group_size,
dtype=torch.uint8,
)
if self.num_bits == 4: w1_zp[:] = 136
w13_qzeros = torch.nn.Parameter(
w1_zp,
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_zp = torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size // self.group_size // 2 if self.num_bits == 4 else hidden_size // self.group_size,
dtype=torch.uint8,
)
if self.num_bits == 4: w2_zp[:] = 136
w2_qzeros = torch.nn.Parameter(
w2_zp,
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Reconfigure packed weights and scales to match moe_wna16 format # Reconfigure packed weights and scales to match moe_wna16 format
layer.w13_weight_packed = torch.nn.Parameter( layer.w13_weight_packed = torch.nn.Parameter(
...@@ -1836,8 +1867,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1836,8 +1867,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
return config_builder( return config_builder(
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
w1_zp=None, w1_zp=layer.w13_qzeros if envs.VLLM_ROCM_USE_AITER_MOE else None,
w2_zp=None, w2_zp=layer.w2_qzeros if envs.VLLM_ROCM_USE_AITER_MOE else None,
block_shape=[0, self.group_size], block_shape=[0, self.group_size],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment