Commit 3ab9494d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev_mtp_sampler' into 'v0.9.2-dev'

feat: add Marlin W16A16 MoE fast path

See merge request dcutoolkit/deeplearing/vllm!294
parents dac87ca7 4f575f17
......@@ -192,6 +192,8 @@ if TYPE_CHECKING:
VLLM_PP_DEBUG: bool = False
VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1255,6 +1257,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT', 'False').lower() in
("true", "1")),
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
import torch
import triton
import triton.language as tl
import lmslim.envs as lsenvs
from vllm.utils import W8a8GetCacheJSON
use_lightop = lsenvs.LMSLIM_USE_LIGHTOP
device_name = lsenvs.LMSLIM_GPU_NAME
num_cus= torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
if use_lightop:
from lightop import moe_gemm_marlin_w16a16, get_moe_cuda_marlin_config_w16a16
from lightop import op as op
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce_triton(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
if token_num <= 32:
BLOCK_M = 1
BLOCK_DIM = 512
NUM_STAGE = 2
num_warps = 4
elif token_num <= 128:
BLOCK_M = 1
BLOCK_DIM = 1024
NUM_STAGE = 0
num_warps = 2
elif token_num <= 4096:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 0
num_warps = 2
else:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 2
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def moe_reduce_dispatch(
intermediate_cache3: torch.Tensor,
out_hidden_states: torch.Tensor,
begin_chunk_idx: int,
end_chunk_idx: int,
routed_scaling_factor: float,
shared_output: Optional[torch.Tensor] = None,
):
inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
n = intermediate_cache3.shape[0]
# 根据 n 大小选择不同的 reduce 实现
if 1 <= n <= 4:
moe_sum_reduce_torch_compile(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 4 < n <= 1024:
moe_sum_reduce_triton(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 1024 < n <= 32768:
ops.moe_sum_opt1(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
else:
ops.moe_sum(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
# 根据 shared_output 是否存在决定怎么更新
if shared_output is not None:
out_hidden_states[begin_chunk_idx:end_chunk_idx].mul_(routed_scaling_factor).add_(shared_output[begin_chunk_idx:end_chunk_idx])
else:
out_hidden_states[begin_chunk_idx:end_chunk_idx].mul_(routed_scaling_factor)
def moe_kernel_prepare_input(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_int4_w4a8: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_int8_w8a8 or use_int4_w4a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 or int4 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
def moe_align_block_size_lightop(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None,
num_local_tokens: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False,
ep_size: int = 8,
num_token: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
local_num_experts = num_experts // ep_size
if num_token:
if num_token < block_size:
max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + local_num_experts * (block_size - 1))
else:
max_num_tokens_padded = topk_ids.numel() + local_num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + local_num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
if expert_map is not None:
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
else:
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.zeros((1),
dtype=torch.int32,
device=topk_ids.device)
op.moe_align_block_size(topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
expert_map,
expert_mask,
num_local_tokens,
expert_map is not None)
return sorted_ids, expert_ids, num_tokens_post_pad
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
cache13: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = 1.0,
shared_output: Optional[torch.Tensor] = None,
num_local_tokens: Optional[torch.Tensor] = None,
expect_m: Optional[int] = -1,
):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
# 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
N = twoN // 2
E2, K_w2, N2_w2 = w2.shape
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.shape[1]
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
N2 = 2 * N
intermediate_cache1 = cache13[:M * top_k_num * N2].view(-1, N2)
intermediate_cache3 = cache13[:M * top_k_num * K]
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty((M * top_k_num, N),
device=hidden_states.device,
dtype=compute_type)
is_ep = expert_map is not None
expert_mask=None
ep_size=None
if is_ep:
expert_mask = torch.zeros((CHUNK_SIZE, top_k_num), dtype=torch.bool, device=hidden_states.device, requires_grad=False)
ep_size = global_num_experts // E
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states, dtype=compute_type)
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
bs = tokens_in_chunk
if num_local_tokens is not None and expect_m != -1:
bs = expect_m
intermediate_cache3 = intermediate_cache3.view(-1, K)
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk * top_k_num]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * top_k_num]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk * top_k_num]
# import logging
# logger = logging.getLogger(__name__)
# if not status:
# logger.info("lightop unsupport this size E:%s, N:%s, K:%s", E, N, K)
config_marlin_0, config_marlin_1, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
N2,
K,
K,
N,
top_k_num,
device_name,
num_cus,
hidden_states.dtype)
assert status, f'lightop unsupport this size E:{E}, N:{N}, K:{K}'
# Align with vLLM's default config handling for W16A16.
# if "BLOCK_SIZE_M" not in config_marlin_0:
# config_marlin_0["BLOCK_SIZE_M"] = 16
# if "BLOCK_SIZE_M" not in config_marlin_1:
# config_marlin_1["BLOCK_SIZE_M"] = config_marlin_0["BLOCK_SIZE_M"]
# if "MODE" not in config_marlin_0:
# config_marlin_0["MODE"] = 412
# if "MODE" not in config_marlin_1:
# config_marlin_1["MODE"] = 411
# if "DELTA" not in config_marlin_0:
# config_marlin_0["DELTA"] = 1
# if "DELTA" not in config_marlin_1:
# config_marlin_1["DELTA"] = 1
block_size_m = config_marlin_0["BLOCK_SIZE_M"]
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if num_local_tokens is None:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, block_size_m,
global_num_experts, expert_map=expert_map,
expert_mask = expert_mask[:tokens_in_chunk] if is_ep else None))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size_lightop(curr_topk_ids, block_size_m, global_num_experts,
expert_map = expert_map,
expert_mask = expert_mask[begin_chunk_idx:end_chunk_idx] if is_ep else None,
num_local_tokens = num_local_tokens,
ep_size=ep_size))
# GEMM1: hidden_states * w1 -> intermediate_cache1
moe_gemm_marlin_w16a16(
curr_hidden_states,
w1_marlin,
intermediate_cache1,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k_num,
config_marlin_0,
)
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1)
# GEMM2: intermediate_cache2 * w2, apply routing weights here.
moe_gemm_marlin_w16a16(
intermediate_cache2,
w2_marlin,
intermediate_cache3,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
1,
config_marlin_1,
)
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
if is_ep:
op.moe_sum(input=intermediate_cache3,
output=out_hidden_states[begin_chunk_idx:end_chunk_idx],
bias = None if shared_output is None else shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask = expert_mask[:tokens_in_chunk],
num_local_tokens=num_local_tokens,
factor=routed_scaling_factor,
)
elif use_lightop and shared_output is not None:
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.shape),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx],
bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask=None,
num_local_tokens=None,
factor=routed_scaling_factor)
elif shared_output is not None:
moe_reduce_dispatch(
intermediate_cache3,
out_hidden_states,
begin_chunk_idx,
end_chunk_idx,
routed_scaling_factor,
shared_output,
)
else:
moe_reduce_dispatch(
intermediate_cache3,
out_hidden_states,
begin_chunk_idx,
end_chunk_idx,
1.0,
None,
)
return out_hidden_states
......@@ -48,7 +48,27 @@ logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
......@@ -1694,6 +1714,71 @@ def fused_experts_impl(
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
# Optional fast path: use lmslim's Marlin W16A16 fused MoE implementation
# when explicitly requested. This reuses the same cache13 buffer as other
# fused paths for consistency.
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_int4_w4a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=None,
w2_scale=None,
w1_zp=None,
w2_zp=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output
)
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
......
import torch
import numpy as np
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
Args:
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
Returns:
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
"""
if tensor_int8.dtype != torch.int8:
raise ValueError("Input tensor must be of type torch.int8")
N, K_half = tensor_int8.shape
tensor_uint8 = tensor_int8.to(torch.uint8)
# 拆分为低4位和高4位
low4 = tensor_uint8 & 0x0F
high4 = (tensor_uint8 >> 4) & 0x0F
# 创建目标 tensor(int32),每个元素只使用低4位
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
# 放置数据:每个值放在 int32 的低4位
unpacked[:, 0::2] = low4.to(torch.int32)
unpacked[:, 1::2] = high4.to(torch.int32)
return unpacked
# 从 [32, 64] int32的size中,重排后 每行相邻的8个uint4数据 混排后 pack成uint32数据
#原本是32 * 16算一次mmac,因为npack组成32 * 64大小
#现在是16 * 16算一次mmac,因为npack组成16 * 32大小
#这里是在对32 * 64 进行数据的重排
def get_weight_perms(interleave: bool=False):
# ================== 4条mmac 指令进行拼接的结果 ============
perm = []
for i in range(64): # 遍历64个线程,因为是针对一个warp内的
for col in range(2): # 遍历列方向2次, 代表2次mmac指令 具体是行还是列还不知道
cur_col = (i % 16) * 2 + col #计算当前线程在哪个列 这里是占据4列
for row in range(4): # 每个线程在 每个mmac中需要取8个uint4数据 占据8行
cur_row = (i // 16) * 4 + row
# 计算在整个 [32, 64]范围内的实际偏移
cur_idx = cur_row * 32 + cur_col
perm.append(cur_idx)
perm = np.array(perm)
if interleave:
# ================= 加入混排策略 =================
# # interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
# # interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
# QQQ 类似的 pack混排策略
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
# 按照 interleave 重排后展成 一维数组
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
#npack重排 //512大小
def marlin_weights_npack2(
q_w,
weight_perm,
k_tile=16,
n_tile=32):
# 2048, 768
size_k, size_n = q_w.shape
# [7168, 512] ==> [128, 16, 24,32]
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
# [128, 16, 24,32] ==> [128, 24, 16,32]
q_w = q_w.permute((0, 2, 1, 3))
# [128, 24, 16,32] ==> [128, 12288]
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
# 按照指定的 perm进行重排
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
# orig_device = q_w.device
# q_w = q_w.cpu().numpy()
# q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
# for i in range(pack_factor):
# q_packed |= q_w[:, i::pack_factor] << 4 * i
# q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_w
#npack重排
def marlin_weights_kpack2(
q_w,
weight_perm,
k_tile=32,
n_tile=16):
# 7168, 512
size_k, size_n = q_w.shape
# [7168, 512] ==> [224, 32, 8,64]
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
# [224, 32, 8,64] ==> [224, 8, 32, 64]
q_w = q_w.permute((0, 2, 1, 3))
# [224, 8, 32, 64] ==> [224, 16384]
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
# 按照指定的 perm进行重排
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
# orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
# q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
# for i in range(pack_factor):
# q_packed |= q_w[:, i::pack_factor] << 4 * i
# q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_w
def w16a16_marlin_weight(full_w16a16_w # [size_n, size_k]
):
# import pdb
# pdb.set_trace()
# [size_n, size_k] == > [size_k, size_n] 此时已经是默认NN的 k * n 基于这个进行重排
full_w16a16_w = full_w16a16_w.T
# 获取 [16, 32]的权重数据块中,需要重排的顺序
weight_perm = get_weight_perms()
# 按照索引进行重排
marlin_q_w = marlin_weights_npack2(full_w16a16_w, weight_perm, k_tile=16, n_tile=32)
return marlin_q_w
if __name__ == "__main__":
print("线程 0 需要的索引: ")
print(get_weight_perms(interleave=False)[:32])
print("线程 1 需要的索引: ")
print(get_weight_perms(interleave=False)[32:64])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment