Commit b66c8e4b authored by zhuwenwen's avatar zhuwenwen
Browse files

Synchronize the modifications from the 12th to the 17th:

修复CompressedTensorsLinearMethod中的w4a16的冲突问题
feat(moe): add Marlin W16A16 fused MoE behind VLLM_USE_MARLIN_W16A16_MOE
replace the fp8_mqa_logits and fp8_paged_mqa_logits interfaces in deepgemm with mqa_logits and paged_mqa_logits from lightop
parent b8ef3436
......@@ -271,6 +271,8 @@ if TYPE_CHECKING:
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
def get_default_cache_root():
......@@ -1747,6 +1749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
("true", "1")),
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
import torch
import triton
import triton.language as tl
import lmslim.envs as lsenvs
use_lightop = lsenvs.LMSLIM_USE_LIGHTOP
device_name = lsenvs.LMSLIM_GPU_NAME
num_cus= torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
if use_lightop:
from lightop import moe_gemm_marlin_w16a16, get_moe_cuda_marlin_config_w16a16
from lightop import op as op
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce_triton(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
if token_num <= 32:
BLOCK_M = 1
BLOCK_DIM = 512
NUM_STAGE = 2
num_warps = 4
elif token_num <= 128:
BLOCK_M = 1
BLOCK_DIM = 1024
NUM_STAGE = 0
num_warps = 2
elif token_num <= 4096:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 0
num_warps = 2
else:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 2
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def moe_reduce_dispatch(
intermediate_cache3: torch.Tensor,
out_hidden_states: torch.Tensor,
begin_chunk_idx: int,
end_chunk_idx: int,
routed_scaling_factor: float,
shared_output: Optional[torch.Tensor] = None,
):
inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
n = intermediate_cache3.shape[0]
# 根据 n 大小选择不同的 reduce 实现
if 1 <= n <= 4:
moe_sum_reduce_torch_compile(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 4 < n <= 1024:
moe_sum_reduce_triton(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 1024 < n <= 32768:
ops.moe_sum_opt1(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
else:
ops.moe_sum(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
# 根据 shared_output 是否存在决定怎么更新
if shared_output is not None:
out_hidden_states[begin_chunk_idx:end_chunk_idx].mul_(routed_scaling_factor).add_(shared_output[begin_chunk_idx:end_chunk_idx])
else:
out_hidden_states[begin_chunk_idx:end_chunk_idx].mul_(routed_scaling_factor)
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
def moe_align_block_size_lightop(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None,
num_local_tokens: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False,
ep_size: int = 8,
num_token: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
local_num_experts = num_experts // ep_size
if num_token:
if num_token < block_size:
max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + local_num_experts * (block_size - 1))
else:
max_num_tokens_padded = topk_ids.numel() + local_num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + local_num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
if expert_map is not None:
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
else:
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.zeros((1),
dtype=torch.int32,
device=topk_ids.device)
op.moe_align_block_size(topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
expert_map,
expert_mask,
num_local_tokens,
expert_map is not None)
return sorted_ids, expert_ids, num_tokens_post_pad
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
cache13: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = 1.0,
shared_output: Optional[torch.Tensor] = None,
num_local_tokens: Optional[torch.Tensor] = None,
expect_m: Optional[int] = -1,
):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
# 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
N = twoN // 2
E2, K_w2, N2_w2 = w2.shape
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.shape[1]
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
N2 = 2 * N
intermediate_cache1 = cache13[:M * top_k_num * N2].view(-1, N2)
intermediate_cache3 = cache13[:M * top_k_num * K]
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty((M * top_k_num, N),
device=hidden_states.device,
dtype=compute_type)
is_ep = expert_map is not None
expert_mask=None
ep_size=None
if is_ep:
expert_mask = torch.zeros((CHUNK_SIZE, top_k_num), dtype=torch.bool, device=hidden_states.device, requires_grad=False)
ep_size = global_num_experts // E
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states, dtype=compute_type)
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
bs = tokens_in_chunk
if num_local_tokens is not None and expect_m != -1:
bs = expect_m
intermediate_cache3 = intermediate_cache3.view(-1, K)
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk * top_k_num]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * top_k_num]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk * top_k_num]
# import logging
# logger = logging.getLogger(__name__)
# if not status:
# logger.info("lightop unsupport this size E:%s, N:%s, K:%s", E, N, K)
config_marlin_0, config_marlin_1, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
N2,
K,
K,
N,
top_k_num,
device_name,
num_cus,
hidden_states.dtype)
assert status, f'lightop unsupport this size E:{E}, N:{N}, K:{K}'
# Align with vLLM's default config handling for W16A16.
# if "BLOCK_SIZE_M" not in config_marlin_0:
# config_marlin_0["BLOCK_SIZE_M"] = 16
# if "BLOCK_SIZE_M" not in config_marlin_1:
# config_marlin_1["BLOCK_SIZE_M"] = config_marlin_0["BLOCK_SIZE_M"]
# if "MODE" not in config_marlin_0:
# config_marlin_0["MODE"] = 412
# if "MODE" not in config_marlin_1:
# config_marlin_1["MODE"] = 411
# if "DELTA" not in config_marlin_0:
# config_marlin_0["DELTA"] = 1
# if "DELTA" not in config_marlin_1:
# config_marlin_1["DELTA"] = 1
block_size_m = config_marlin_0["BLOCK_SIZE_M"]
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if num_local_tokens is None:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, block_size_m,
global_num_experts, expert_map=expert_map,
expert_mask = expert_mask[:tokens_in_chunk] if is_ep else None))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size_lightop(curr_topk_ids, block_size_m, global_num_experts,
expert_map = expert_map,
expert_mask = expert_mask[begin_chunk_idx:end_chunk_idx] if is_ep else None,
num_local_tokens = num_local_tokens,
ep_size=ep_size))
# GEMM1: hidden_states * w1 -> intermediate_cache1
moe_gemm_marlin_w16a16(
curr_hidden_states,
w1_marlin,
intermediate_cache1,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k_num,
config_marlin_0,
)
if (envs.VLLM_USE_FUSE_SILU_AND_MUL
and intermediate_cache1.dtype == intermediate_cache2.dtype
== torch.float16):
from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1, intermediate_cache2)
else:
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1)
# GEMM2: intermediate_cache2 * w2, apply routing weights here.
moe_gemm_marlin_w16a16(
intermediate_cache2,
w2_marlin,
intermediate_cache3,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
1,
config_marlin_1,
)
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask=None, num_local_tokens=None, factor=routed_scaling_factor)
else:
if envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
\ No newline at end of file
import torch
import numpy as np
# 从 [32, 64] int32的size中,重排后 每行相邻的8个uint4数据 混排后 pack成uint32数据
#原本是32 * 16算一次mmac,因为npack组成32 * 64大小
#现在是16 * 16算一次mmac,因为npack组成16 * 32大小
#这里是在对32 * 64 进行数据的重排
def get_weight_perms(interleave: bool=False):
# ================== 4条mmac 指令进行拼接的结果 ============
perm = []
for i in range(64): # 遍历64个线程,因为是针对一个warp内的
for col in range(2): # 遍历列方向2次, 代表2次mmac指令 具体是行还是列还不知道
cur_col = (i % 16) * 2 + col #计算当前线程在哪个列 这里是占据4列
for row in range(4): # 每个线程在 每个mmac中需要取8个uint4数据 占据8行
cur_row = (i // 16) * 4 + row
# 计算在整个 [32, 64]范围内的实际偏移
cur_idx = cur_row * 32 + cur_col
perm.append(cur_idx)
perm = np.array(perm)
if interleave:
# ================= 加入混排策略 =================
# # interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
# # interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
# QQQ 类似的 pack混排策略
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
# 按照 interleave 重排后展成 一维数组
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
#npack重排 //512大小
def marlin_weights_npack2(
q_w,
weight_perm,
k_tile=16,
n_tile=32):
# 2048, 768
size_k, size_n = q_w.shape
# [7168, 512] ==> [128, 16, 24,32]
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
# [128, 16, 24,32] ==> [128, 24, 16,32]
q_w = q_w.permute((0, 2, 1, 3))
# [128, 24, 16,32] ==> [128, 12288]
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
# 按照指定的 perm进行重排
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
# orig_device = q_w.device
# q_w = q_w.cpu().numpy()
# q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
# for i in range(pack_factor):
# q_packed |= q_w[:, i::pack_factor] << 4 * i
# q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_w
def w16a16_marlin_weight(full_w16a16_w # [size_n, size_k]
):
# import pdb
# pdb.set_trace()
# [size_n, size_k] == > [size_k, size_n] 此时已经是默认NN的 k * n 基于这个进行重排
full_w16a16_w = full_w16a16_w.T
# 获取 [16, 32]的权重数据块中,需要重排的顺序
weight_perm = get_weight_perms()
# 按照索引进行重排
marlin_q_w = marlin_weights_npack2(full_w16a16_w, weight_perm, k_tile=16, n_tile=32)
return marlin_q_w
\ No newline at end of file
......@@ -875,6 +875,16 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weights_scheme = (
self.quantization_config
.target_scheme_map.get('Linear', {})
.get('weights')
)
if weights_scheme is not None:
num_bits = weights_scheme.num_bits
if num_bits == 4:
return layer.scheme.process_weights_after_loading(layer)
n=layer.weight.shape[0]
k=layer.weight.shape[1]
......
......@@ -225,6 +225,6 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None, **kw
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
......@@ -79,7 +79,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
......@@ -96,6 +95,11 @@ from .utils import (
maybe_prefix,
)
if current_platform.is_rocm():
from lightop import op, gemmopt
else:
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
......@@ -694,9 +698,11 @@ def sparse_attn_indexer(
)
fp8_mqa_logits_func = fp8_mqa_logits
if current_platform.is_rocm():
from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
# from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
fp8_mqa_logits_func = rocm_fp8_mqa_logits
# fp8_mqa_logits_func = rocm_fp8_mqa_logits
fp8_mqa_logits_func = op.mqa_logits
logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
......@@ -744,11 +750,12 @@ def sparse_attn_indexer(
num_padded_tokens = batch_size * next_n
fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
if current_platform.is_rocm():
from vllm.attention.ops.rocm_aiter_mla_sparse import (
rocm_fp8_paged_mqa_logits,
)
# from vllm.attention.ops.rocm_aiter_mla_sparse import (
# rocm_fp8_paged_mqa_logits,
# )
fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
# fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
fp8_paged_mqa_logits_func = gemmopt.paged_mqa_logits
logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens,
kv_cache,
......
......@@ -82,6 +82,7 @@ from vllm import envs
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -305,6 +306,60 @@ class Qwen3MoeAttention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions: torch.Tensor,
query: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -313,20 +368,48 @@ class Qwen3MoeAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
if envs.VLLM_USE_FUSED_RMS_ROPE :
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......
......@@ -19,6 +19,8 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.platforms import current_platform
from lightop import gemmopt
logger = init_logger(__name__)
......@@ -336,9 +338,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
if current_platform.is_rocm():
self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment