Commit ffcc47b7 authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev_mtp_sampler' into 'v0.9.2-dev'

Marlin W16A16 MoE: 清理未用量化接口与辅助代码,合入算子优化

See merge request dcutoolkit/deeplearing/vllm!298
parents 8548cf87 80e8f551
...@@ -10,13 +10,10 @@ import torch ...@@ -10,13 +10,10 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
import lmslim.envs as lsenvs import lmslim.envs as lsenvs
from vllm.utils import W8a8GetCacheJSON
use_lightop = lsenvs.LMSLIM_USE_LIGHTOP use_lightop = lsenvs.LMSLIM_USE_LIGHTOP
device_name = lsenvs.LMSLIM_GPU_NAME device_name = lsenvs.LMSLIM_GPU_NAME
num_cus= torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count num_cus= torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
if use_lightop: if use_lightop:
from lightop import moe_gemm_marlin_w16a16, get_moe_cuda_marlin_config_w16a16 from lightop import moe_gemm_marlin_w16a16, get_moe_cuda_marlin_config_w16a16
...@@ -156,41 +153,6 @@ def moe_reduce_dispatch( ...@@ -156,41 +153,6 @@ def moe_reduce_dispatch(
else: else:
out_hidden_states[begin_chunk_idx:end_chunk_idx].mul_(routed_scaling_factor) out_hidden_states[begin_chunk_idx:end_chunk_idx].mul_(routed_scaling_factor)
def moe_kernel_prepare_input(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    use_int4_w4a8: bool,
    per_channel_quant: bool,
    block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Validate the quantization arguments and quantize activations if needed.

    Three cases are distinguished by the mode flags:

    * ``use_int8_w8a8`` or ``use_int4_w4a8``: ``A`` is quantized here, either
      with ``per_token_quant_int8`` (when ``block_shape`` is ``None``) or with
      ``per_token_group_quant_int8`` using ``block_shape[1]`` as group size.
    * ``use_int8_w8a16`` or ``use_int4_w4a16``: only the arguments are
      checked; ``A`` and ``A_scale`` are returned untouched.
    * otherwise: both scales must be ``None`` and ``A`` is returned untouched.

    Args:
        A: Activation tensor.
        B: Weight tensor. Not read by this function.
        A_scale: Incoming activation scale; replaced when ``A`` is quantized.
        B_scale: Weight scale; must be provided for every quantized mode.
        use_fp8_w8a8: Accepted for signature compatibility but never read, so
            an fp8-only call takes the unquantized branch and requires both
            scales to be ``None``.
        use_int8_w8a8: Selects activation quantization.
        use_int8_w8a16: Selects the weight-only check branch.
        use_int4_w4a16: Selects the weight-only check branch.
        use_int4_w4a8: Selects activation quantization.
        per_channel_quant: Must be true when quantizing activations without a
            ``block_shape``.
        block_shape: Optional two-element block shape; only the second entry
            is used, as the group size for block-wise quantization.

    Returns:
        The (possibly quantized) activations and their scale.

    Raises:
        AssertionError: If the flag/scale/block-shape combination is invalid.
    """
    if use_int8_w8a8 or use_int4_w4a8:
        assert B_scale is not None
        if block_shape is None:
            # Activation channel-wise int8 quantization.
            assert (per_channel_quant
                    ), "int8 or int4 quantization only supports block or channel-wise"
            A, A_scale = per_token_quant_int8(A)
        else:
            # Activation block-wise int8 quantization.
            assert len(block_shape) == 2
            block_k = block_shape[1]
            A, A_scale = per_token_group_quant_int8(A, block_k)
            # One scale per group of block_k elements along the last axis.
            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None
    return A, A_scale
def round_up(x: int, y: int) -> int:
    """Return ``x`` rounded up to a multiple of ``y``.

    Uses integer floor division, so ``y == 0`` raises ``ZeroDivisionError``.
    The result is the smallest multiple of ``y`` that is >= ``x`` when ``y``
    is positive.
    """
    return ((x + y - 1) // y) * y
...@@ -262,21 +224,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -262,21 +224,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
...@@ -415,8 +364,13 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -415,8 +364,13 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
top_k_num, top_k_num,
config_marlin_0, config_marlin_0,
) )
if (envs.VLLM_USE_FUSE_SILU_AND_MUL
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1) and intermediate_cache1.dtype == intermediate_cache2.dtype
== torch.float16):
from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1, intermediate_cache2)
else:
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1)
# GEMM2: intermediate_cache2 * w2, apply routing weights here. # GEMM2: intermediate_cache2 * w2, apply routing weights here.
moe_gemm_marlin_w16a16( moe_gemm_marlin_w16a16(
...@@ -432,38 +386,21 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -432,38 +386,21 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
) )
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
if is_ep: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
op.moe_sum(input=intermediate_cache3, from lightop import op as op
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
bias = None if shared_output is None else shared_output[begin_chunk_idx:end_chunk_idx], output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask = expert_mask[:tokens_in_chunk], expert_mask=None, num_local_tokens=None, factor=routed_scaling_factor)
num_local_tokens=num_local_tokens,
factor=routed_scaling_factor,
)
elif use_lightop and shared_output is not None:
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.shape),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx],
bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask=None,
num_local_tokens=None,
factor=routed_scaling_factor)
elif shared_output is not None:
moe_reduce_dispatch(
intermediate_cache3,
out_hidden_states,
begin_chunk_idx,
end_chunk_idx,
routed_scaling_factor,
shared_output,
)
else: else:
moe_reduce_dispatch( if envs.VLLM_USE_LIGHTOP_MOE_SUM:
intermediate_cache3, from lightop import op as op
out_hidden_states, op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
begin_chunk_idx, output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
end_chunk_idx, expert_mask=None, num_local_tokens=None, factor=1.0)
1.0, elif envs.VLLM_USE_OPT_MOE_SUM:
None, moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
) else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
...@@ -1759,21 +1759,8 @@ def fused_experts_impl( ...@@ -1759,21 +1759,8 @@ def fused_experts_impl(
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_int4_w4a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=None,
w2_scale=None,
w1_zp=None,
w2_zp=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
use_nn_moe=False, use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output shared_output=shared_output
......
...@@ -2,36 +2,6 @@ ...@@ -2,36 +2,6 @@
import torch import torch
import numpy as np import numpy as np
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """Unpack a ``[N, K // 2]`` int8 tensor into a ``[N, K]`` int32 tensor.

    Every int8 element holds two 4-bit values. Each one is written to the low
    4 bits of its own int32 element (all remaining bits are zero), with the
    low nibble placed before the high nibble.

    Args:
        tensor_int8: Input tensor of shape ``[N, K // 2]`` and dtype
            ``torch.int8``.

    Returns:
        Tensor of shape ``[N, K]`` and dtype ``torch.int32`` on the same
        device as the input.

    Raises:
        ValueError: If the input is not ``torch.int8`` or is not 2-D.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")
    if tensor_int8.dim() != 2:
        # Fail with a clear message instead of a tuple-unpacking error below.
        raise ValueError("Input tensor must be 2-D with shape [N, K // 2]")

    N, K_half = tensor_int8.shape
    tensor_uint8 = tensor_int8.to(torch.uint8)

    # Split every byte into its low and high 4-bit halves.
    low4 = tensor_uint8 & 0x0F
    high4 = (tensor_uint8 >> 4) & 0x0F

    # Destination tensor: one int32 per nibble, only the low 4 bits are used.
    unpacked = torch.empty((N, K_half * 2),
                           dtype=torch.int32,
                           device=tensor_int8.device)

    # Interleave: even columns take the low nibble, odd columns the high one.
    unpacked[:, 0::2] = low4.to(torch.int32)
    unpacked[:, 1::2] = high4.to(torch.int32)
    return unpacked
# 从 [32, 64] int32的size中,重排后 每行相邻的8个uint4数据 混排后 pack成uint32数据 # 从 [32, 64] int32的size中,重排后 每行相邻的8个uint4数据 混排后 pack成uint32数据
#原本是32 * 16算一次mmac,因为npack组成32 * 64大小 #原本是32 * 16算一次mmac,因为npack组成32 * 64大小
...@@ -95,36 +65,6 @@ def marlin_weights_npack2( ...@@ -95,36 +65,6 @@ def marlin_weights_npack2(
return q_w return q_w
def marlin_weights_kpack2(
        q_w,
        weight_perm,
        k_tile=32,
        n_tile=16):
    """Tile and permute a 2-D weight tensor, returning an unpacked uint32 array.

    Args:
        q_w: 2-D torch tensor of shape ``[size_k, size_n]``.
        weight_perm: Index tensor applied to every group of
            ``weight_perm.numel()`` consecutive elements of the tiled layout.
        k_tile: Tile extent along the first (k) axis.
        n_tile: Tile extent along the second (n) axis.

    Returns:
        ``numpy.ndarray`` of dtype ``uint32`` with shape
        ``[size_k // k_tile, size_n * k_tile]``. No bit-packing is performed;
        each source element occupies one uint32.
    """
    size_k, size_n = q_w.shape

    # [size_k, size_n] -> [k_blocks, k_tile, n_blocks, n_tile]
    q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
    # Bring the two block indices together:
    # -> [k_blocks, n_blocks, k_tile, n_tile]
    q_w = q_w.permute((0, 2, 1, 3))
    # Flatten each row of k-blocks: -> [k_blocks, size_n * k_tile]
    q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))

    # Reorder every group of weight_perm.numel() elements by weight_perm.
    q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)

    # Hand the result back on the host as uint32.
    q_w = q_w.cpu().numpy().astype(np.uint32)
    return q_w
def w16a16_marlin_weight(full_w16a16_w # [size_n, size_k] def w16a16_marlin_weight(full_w16a16_w # [size_n, size_k]
): ):
# import pdb # import pdb
...@@ -136,10 +76,3 @@ def w16a16_marlin_weight(full_w16a16_w # [size_n, size_k] ...@@ -136,10 +76,3 @@ def w16a16_marlin_weight(full_w16a16_w # [size_n, size_k]
# 按照索引进行重排 # 按照索引进行重排
marlin_q_w = marlin_weights_npack2(full_w16a16_w, weight_perm, k_tile=16, n_tile=32) marlin_q_w = marlin_weights_npack2(full_w16a16_w, weight_perm, k_tile=16, n_tile=32)
return marlin_q_w return marlin_q_w
if __name__ == "__main__":
    # Debug aid: print two consecutive 32-entry windows of the permutation
    # returned by get_weight_perms(interleave=False); the labels read
    # "indices needed by thread 0" and "indices needed by thread 1".
    print("线程 0 需要的索引: ")
    print(get_weight_perms(interleave=False)[:32])
    print("线程 1 需要的索引: ")
    print(get_weight_perms(interleave=False)[32:64])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment