Commit c637d1aa authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.15.1-dev-pcp' into v0.15.1-dev-pcp

parents f3d1f95b 263d6216
...@@ -733,7 +733,8 @@ class FusedMoE(CustomOp): ...@@ -733,7 +733,8 @@ class FusedMoE(CustomOp):
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod", if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod",
"SlimQuantW4A8Int8MoEMethod", "SlimQuantW4A8Int8MoEMethod",
"SlimQuantW4A8Int8MarlinMoEMethod")): "SlimQuantW4A8Int8MarlinMoEMethod",
"SlimQuantW4A8Int8AiterMoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
......
...@@ -25,6 +25,21 @@ import os ...@@ -25,6 +25,21 @@ import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from aiter.ops.shuffle import w4a8_moe_layout_shuffle_gemm1,w4a8_moe_layout_shuffle_gemm2
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter import dtypes, ActivationType
except ImportError as e:
print("Import error msg: import aiter")
W8A8_TRITONJSON=W8a8GetCacheJSON() W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor, def baseline_scaled_mm(a: torch.Tensor,
...@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig): ...@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MoEMethod(self, layer.moe_config) if envs.VLLM_ROCM_USE_AITER_MOE:
return SlimQuantW4A8Int8AiterMoEMethod(self, layer.moe_config)
else:
return SlimQuantW4A8Int8MoEMethod(self, layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -328,4 +346,214 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -328,4 +346,214 @@ class SlimQuantW4A8Int8MoEMethod:
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
)
class SlimQuantW4A8Int8AiterMoEMethod:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.moe_mk: Optional[FusedMoEModularKernel] = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
self.moe_quant_config = FusedMoEQuantConfig.make(
torch.int8,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=None,
weight_dtype='int4'
) )
return self.moe_quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def repack_and_shuffle_w4a8(self, weight_data, E):
"""
逐 expert 处理 [n, k_half]
处理完直接写回 weight_data[i]
"""
# 原始 shape: [E, n, k_half]
for i in range(E):
# 1. 取当前 expert [n, k_half]
expert = weight_data[i]
n, k_half = expert.shape
# 2. repack 逻辑(连续 → blocked)
w_u8 = expert.to(torch.uint8)
# 解包 1byte → 2个4bit
w_unpacked = torch.stack([
(w_u8 >> 4) & 0x0F,
w_u8 & 0x0F
], dim=-1).view(n, -1)
# 8个4bit分块重排
blocks = w_unpacked.view(n, -1, 8)
w_low = blocks[..., :4]
w_high = blocks[..., 4:]
packed = (w_low << 4) | w_high
packed = packed.view(n, k_half)
# 3. shuffle
w_marlin_in = w4a8_moe_layout_shuffle_gemm2(packed)
w_marlin_in = w_marlin_in.reshape(n, k_half)
# 4. 直接写回
weight_data[i] = w_marlin_in
return weight_data
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
layer.w13_weight = Parameter(self.repack_and_shuffle_w4a8(layer.w13_weight.data, E), requires_grad=False)
layer.w2_weight = Parameter(self.repack_and_shuffle_w4a8(layer.w2_weight.data, E), requires_grad=False)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
E = layer.w13_weight.size(0)
K = x.size(-1)
N1 = layer.w13_weight.size(1)
if x.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
M = x.size(0)
else:
assert x.dim() == 3
assert x.size(0) == E, f"{x.size(0)} == {E}"
M = x.size(1)
topk = topk_ids.size(1)
status, moe_cfg = get_aiter_moe_config(
M=M,
E=E,
N1=N1,
N2=N1//2,
K=K,
top_k=topk,
block_size=None,
dtype=dtypes.bf16,
quant_type=MoeQuantType.W4A8,
)
if not status:
assert moe_cfg.solution_type is None
assert moe_cfg.config is None
logger.info(f"[get_config_w4a8] {M=}, no solution found")
return aiter_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
moe_cfg,
layer.w13_weight_scale,
layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
global_num_experts=E,
expert_map=None,
activation="silu"
)
\ No newline at end of file
...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod, SlimQuantW4A8Int8AiterMoEMethod
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
...@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MarlinMoEMethod(self, layer.moe_config) if envs.VLLM_ROCM_USE_AITER_MOE:
return SlimQuantW4A8Int8AiterMoEMethod(self, layer.moe_config)
else:
return SlimQuantW4A8Int8MarlinMoEMethod(self, layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
......
...@@ -30,6 +30,7 @@ elif current_platform.is_xpu(): ...@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
_GLOBAL_LOGITS_BUFFERS = {}
@maybe_transfer_kv_layer @maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
...@@ -50,7 +51,21 @@ def sparse_attn_indexer( ...@@ -50,7 +51,21 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
if q_fp8.dtype == fp8_dtype:
MAX_ELEMENTS = 65536 * 65536
elif q_fp8.dtype in (torch.bfloat16, torch.float16):
MAX_ELEMENTS = 16384 * 32768
else:
MAX_ELEMENTS = 16384 * 32768
device = q_fp8.device
if device not in _GLOBAL_LOGITS_BUFFERS or _GLOBAL_LOGITS_BUFFERS[device].numel() < MAX_ELEMENTS:
_GLOBAL_LOGITS_BUFFERS[device] = torch.empty(
MAX_ELEMENTS,
dtype=torch.float32,
device=device
)
logits_buffer = _GLOBAL_LOGITS_BUFFERS[device]
# assert isinstance(attn_metadata, dict) # assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run # Reserve workspace for indexer during profiling run
...@@ -140,18 +155,21 @@ def sparse_attn_indexer( ...@@ -140,18 +155,21 @@ def sparse_attn_indexer(
weights_all = weights[chunk.token_start:chunk.token_end] weights_all = weights[chunk.token_start:chunk.token_end]
ks_all = chunk.cu_seqlen_ks ks_all = chunk.cu_seqlen_ks
ke_all = chunk.cu_seqlen_ke ke_all = chunk.cu_seqlen_ke
num_q = q_all.shape[0] num_q = q_all.shape[0]
num_k = k_fp8.shape[0] num_k = k_fp8.shape[0]
MAX_ELEMENTS = 1024 * 1024 * 1024 # 4GB is_q_fp16_bf16 = q_all.dtype in (torch.float16, torch.bfloat16)
if (num_q <= 65536 and num_k <= 65536): # if num_q <= 65536 and num_k <= 65536 and (num_q * num_k <= MAX_ELEMENTS): align_size = 128 if is_q_fp16_bf16 else 1
MAX_Q_CHUNK = max(1, num_q)
else: kv_seq_len_aligned = (num_k + align_size - 1) // align_size * align_size
MAX_Q_CHUNK = max(1024, MAX_ELEMENTS // max(1, num_k))
MAX_Q_CHUNK = min(MAX_Q_CHUNK, max(1, num_q)) current_capacity = logits_buffer.numel()
MAX_Q_CHUNK = current_capacity // max(1, kv_seq_len_aligned)
if align_size > 1:
MAX_Q_CHUNK = (MAX_Q_CHUNK // align_size) * align_size
MAX_Q_CHUNK = max(1, MAX_Q_CHUNK)
#存储q的起始和终止地址
slices = [] slices = []
for start_idx in range(0, num_q, MAX_Q_CHUNK): for start_idx in range(0, num_q, MAX_Q_CHUNK):
...@@ -161,13 +179,19 @@ def sparse_attn_indexer( ...@@ -161,13 +179,19 @@ def sparse_attn_indexer(
for q_start, q_end in slices: for q_start, q_end in slices:
if q_end <= q_start: if q_end <= q_start:
continue continue
q_slice = q_all[q_start:q_end] q_slice = q_all[q_start:q_end]
weights_slice = weights_all[q_start:q_end] weights_slice = weights_all[q_start:q_end]
ks_slice = ks_all[q_start:q_end] ks_slice = ks_all[q_start:q_end]
ke_slice = ke_all[q_start:q_end] ke_slice = ke_all[q_start:q_end]
q_len = q_end - q_start
q_seq_len_aligned = (q_len + align_size - 1) // align_size * align_size
required_size = q_seq_len_aligned * kv_seq_len_aligned
logits_slice_view = logits_buffer[:required_size].view(q_seq_len_aligned, kv_seq_len_aligned)
if not current_platform.is_rocm(): if not current_platform.is_rocm():
logits_slice = fp8_mqa_logits( logits_slice = fp8_mqa_logits(
q_slice, q_slice,
...@@ -177,40 +201,44 @@ def sparse_attn_indexer( ...@@ -177,40 +201,44 @@ def sparse_attn_indexer(
ke_slice, ke_slice,
) )
elif get_gcn_arch_name() == "gfx938": elif get_gcn_arch_name() == "gfx938":
logits_slice = op.mqa_logits( op.mqa_logits(
q_slice, q_slice,
k_fp8, k_fp8,
weights_slice, weights_slice,
ks_slice, ks_slice,
ke_slice, ke_slice,
q_slice.shape[0], q_slice.shape[0],
k_fp8.shape[0], k_fp8.shape[0],
q_slice.shape[1], q_slice.shape[1],
q_slice.shape[2], q_slice.shape[2],
k_scale.view(torch.float32).flatten(), k_scale.view(torch.float32).flatten(),
True True,
logits_slice_view
) )
logits_slice = logits_slice_view[:q_len, :num_k]
else: else:
logits_slice = op.mqa_logits( op.mqa_logits(
q_slice, q_slice,
k_fp8, k_fp8,
weights_slice.to(torch.float32), weights_slice.to(torch.float32),
ks_slice, ks_slice,
ke_slice, ke_slice,
q_slice.shape[0], q_slice.shape[0],
k_fp8.shape[0], k_fp8.shape[0],
q_slice.shape[1], q_slice.shape[1],
q_slice.shape[2], q_slice.shape[2],
None, None,
True True,
logits_slice_view
) )
logits_slice = logits_slice_view[:q_len, :num_k]
num_rows_slice = logits_slice.shape[0] num_rows_slice = logits_slice.shape[0]
topk_indices_slice = topk_indices_buffer[ topk_indices_slice = topk_indices_buffer[
chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens
] ]
if not envs.USE_LIGHTOP_TOPK: if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill( torch.ops._C.top_k_per_row_prefill(
logits_slice, logits_slice,
...@@ -460,6 +488,4 @@ class SparseAttnIndexer(CustomOp): ...@@ -460,6 +488,4 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len, self.max_model_len,
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment