Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
......@@ -35,6 +35,8 @@ if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
......@@ -45,4 +47,5 @@ if HAS_TRITON:
"fused_experts",
"get_config_file_name",
"grouped_topk",
"cutlass_moe_fp8",
]
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
from typing import Optional
import torch
from vllm import _custom_ops as ops
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1_q.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2_q.shape[2], "W2 scale shape mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert ab_strides1.shape[0] == w1_q.shape[
0], "AB Strides 1 expert number mismatch"
assert c_strides1.shape[0] == w1_q.shape[
0], "C Strides 1 expert number mismatch"
assert ab_strides2.shape[0] == w2_q.shape[
0], "AB Strides 2 expert number mismatch"
assert c_strides2.shape[0] == w2_q.shape[
0], "C Strides 2 expert number mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
topk = topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
device = a_q.device
expert_offsets = torch.empty((num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,
ab_strides1, c_strides1)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
intemediate_q, a2_scale = ops.scaled_fp8_quant(
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
# Gather tokens
c2 = c2[c_map].view(m, topk, k)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional, Tuple
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache)
from vllm.utils import round_up
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def _valid_deep_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
expert_map: Optional[torch.Tensor] = None) -> bool:
"""
Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm:
return False
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
# Expert maps not supported yet.
if expert_map is not None:
return False
align = dg.get_m_alignment_for_contiguous_layout()
M = hidden_states.shape[0]
_, K, N = w2.shape
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if N <= 512:
return False
if align > M or N % align != 0 or K % align != 0:
return False
return (hidden_states.is_contiguous() and w1.is_contiguous()
and w2.is_contiguous())
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.shape[1]
tokens_in_chunk, _ = curr_hidden_states.shape
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
global_num_experts,
expert_map,
pad_sorted_ids=True))
inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.shape
K = curr_hidden.shape[1]
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def deep_gemm_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
assert expert_map is None, "Expert maps not supported yet"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
0] == hidden_states.shape[0], "Input scale shape mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
block_m = dg.get_m_alignment_for_contiguous_layout()
block_shape = [block_m, block_m]
assert w1_scale is not None
assert w2_scale is not None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
num_chunks = (num_tokens // CHUNK_SIZE) + 1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.empty(M_sum * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
workspace2 = torch.empty((M_sum, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
for chunk in range(num_chunks):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
a1_scale, block_shape)
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
curr_topk_ids, global_num_experts,
expert_map, block_m)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
curr_M = sorted_token_ids.numel()
workspace1 = _resize_cache(workspace1, (curr_M, N))
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
workspace3 = _resize_cache(workspace3, (curr_M, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
expert_ids)
if activation == "silu":
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
_moe_unpermute_and_reduce(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
return out_hidden_states
......@@ -12,14 +12,18 @@ import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
rocm_aiter_fused_experts,
rocm_aiter_topk_softmax)
logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_")
......@@ -671,248 +675,13 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
def ceil_div(a, b):
return (a + b - 1) // b
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: torch.Tensor = None,
num_token: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
if num_token:
if num_token < block_size:
max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + num_experts * (block_size - 1))
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
......@@ -926,33 +695,24 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights.stride(1) == 1
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_int8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
......@@ -960,6 +720,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
EM = sorted_token_ids.shape[0]
if A.shape[0] < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
......@@ -977,20 +740,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert B_zp is None or B_zp.ndim == 3
if os.environ.get('moe_wna16_use_cuda') == '1':
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=topk_ids.numel(),
group_size=block_shape[1],
num_experts=B.shape[0],
bit=4 if use_int4_w4a16 else 8)
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.shape[0],
bit=4 if use_int4_w4a16 else 8)
config = config.copy()
config.update(
get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=topk_ids.numel(),
num_valid_tokens=num_tokens,
size_k=A.shape[1],
size_n=B.shape[1],
num_experts=B.shape[1],
group_size=block_shape[1],
real_top_k=topk_ids.shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"]))
if use_moe_wna16_cuda:
......@@ -1055,7 +818,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
......@@ -1079,7 +842,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else:
# config = config.copy()
# BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
......@@ -1099,7 +861,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B.shape[1] if not use_nn_moe else B.shape[2],
A.shape[1],
EM,
topk_ids.numel(),
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
......@@ -1352,12 +1114,34 @@ def try_get_optimal_moe_config(
return config
def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
return rocm_aiter_topk_softmax
return vllm_topk_softmax
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
......@@ -1376,30 +1160,29 @@ def fused_topk(
dtype=torch.int32,
device=hidden_states.device)
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
token_expert_indicies,
gating_output_float, renormalize)
del token_expert_indicies # Not used. Will be used in the future.
return topk_weights, topk_ids
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
......@@ -1448,11 +1231,12 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False):
def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False) -> Optional[str]:
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a8:
......@@ -1474,6 +1258,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: Optional[str] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1489,10 +1274,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, use_nn_moe)
def inplace_fused_experts_fake(
......@@ -1502,6 +1287,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: Optional[str] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1534,6 +1320,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: Optional[str] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1549,10 +1336,11 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8,use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, use_nn_moe)
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
def outplace_fused_experts_fake(
......@@ -1587,6 +1375,24 @@ direct_register_custom_op(
)
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
torch.ops.vllm.inplace_fused_experts(**kwargs)
hidden_states = kwargs['hidden_states']
return hidden_states
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
return torch.ops.vllm.outplace_fused_experts(**kwargs)
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -1594,6 +1400,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1607,21 +1414,50 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
if inplace:
torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
return hidden_states
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
def fused_experts_impl(hidden_states: torch.Tensor,
......@@ -1631,6 +1467,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1663,10 +1500,13 @@ def fused_experts_impl(hidden_states: torch.Tensor,
]
num_tokens, _ = hidden_states.shape
if use_nn_moe:
E, _, N = w1.shape
else:
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.shape[1]
......@@ -1695,13 +1535,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1] if not use_nn_moe else w2.shape[2]),
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view(
(M, topk_ids.shape[1], N))
intermediate_cache3 = cache13[:M * top_k_num * (w2.shape[1] if not use_nn_moe else w2.shape[2])].view(
(M, topk_ids.shape[1], w2.shape[1] if not use_nn_moe else w2.shape[2]))
intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
......@@ -1745,6 +1583,16 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
a1q_scale = a1_scale
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
......@@ -1771,30 +1619,27 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_stages": 0,
"num_warps": 4
}
# sorted_token_ids, expert_ids, num_tokens_post_padded = (
# moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
# global_num_experts, expert_map))
if use_int4_w4a16:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map, curr_hidden_states.shape[0]))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map, curr_hidden_states.shape[0]))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map))
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(curr_hidden_states,
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
a1_scale,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
......@@ -1813,6 +1658,16 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
else:
qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
......@@ -1840,18 +1695,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_warps": 4
}
invoke_fused_moe_kernel(intermediate_cache2,
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
......@@ -1864,6 +1718,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
......
......@@ -9,7 +9,7 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
from vllm import envs
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -67,6 +67,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
raise NotImplementedError
......@@ -135,6 +137,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
......@@ -162,24 +176,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
use_nn_moe=use_nn_moe,)
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe)
def forward_cuda(
self,
......@@ -196,6 +213,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
......@@ -211,16 +229,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe,)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe)
def forward_cpu(
self,
......@@ -238,10 +258,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_nn_moe: Optional[bool] = False,
**kwargs,
):
assert activation == "silu", f"{activation} is not supported."
assert apply_router_weight_on_input is False
return layer.ipex_fusion(
x,
use_grouped_topk,
......@@ -265,16 +287,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert layer is not None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for HPU.")
......@@ -299,12 +326,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU.")
......@@ -321,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
renormalize=renormalize)
forward_native = forward_cuda
forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda
def determine_expert_map(
......@@ -410,6 +439,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
super().__init__()
......@@ -484,7 +514,7 @@ class FusedMoE(torch.nn.Module):
"non-grouped topk.")
if current_platform.is_hpu():
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
......@@ -500,7 +530,9 @@ class FusedMoE(torch.nn.Module):
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
else:
self.use_nn_moe = False
self.apply_router_weight_on_input = apply_router_weight_on_input
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
......@@ -736,8 +768,9 @@ class FusedMoE(torch.nn.Module):
tp_rank=self.tp_rank)
return
# Case weight scales and zero_points
if ("scale" in weight_name or "zero" in weight_name):
# Case weight scales, zero_points and offset
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
......@@ -886,6 +919,7 @@ class FusedMoE(torch.nn.Module):
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
)
......@@ -923,32 +957,6 @@ class FusedMoE(torch.nn.Module):
]
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
def extra_repr(self) -> str:
s = (
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.utils import round_up
def ceil_div(a, b):
return (a + b - 1) // b
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False,
num_token: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
if num_token:
if num_token < block_size:
max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + num_experts * (block_size - 1))
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_MOE \
and envs.VLLM_ROCM_USE_AITER \
def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
return is_rocm_aiter_moe_enabled() and \
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
def rocm_aiter_fused_experts(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
**kwagrs # Ignore additional keyword arguments
) -> torch.Tensor:
import aiter as rocm_aiter
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
assert w1_scale is not None
assert w2_scale is not None
local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
dtype = hidden_states.dtype
# The default block sizes are 128 in AITER.
if block_shape is None:
block_shape = [128, 128]
scale_blk_k = block_shape[1]
(
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
out_asm,
) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
topk_weights,
E,
model_dim,
dtype,
expert_mask=expert_mask)
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
rocm_aiter.fmoe_fp8_blockscale_g1u1(
out_asm,
a1,
w1,
w2,
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
topk,
w1_scale.view(local_E, -1),
w2_scale.view(local_E, -1),
a1_scale.t().contiguous(),
block_shape[0],
block_shape[1],
None,
)
return out_asm
elif use_fp8_w8a8:
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False)
return rocm_aiter.ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
import aiter as rocm_aiter
rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
gating_output, renormalize)
return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Args:
*tensors: Variable number of torch.Tensor objects.
Returns:
A tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors)
def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.
Args:
*tensors: A variable number of torch.Tensor objects.
expansion_dims: A list of expansion dimensions
corresponding to each tensor.
Returns:
A tuple of tensors with expanded dimensions.
"""
assert len(tensors) == len(expansion_dims), \
"Number of tensors must match the number of expansion dimensions."
return tuple(
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
for tensor, dim in zip(tensors, expansion_dims))
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import cdiv
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(v) <= x.numel()
return x.flatten()[:prod(v)].view(*v)
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
return A, A_scale
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
A permutation routine that works on fp8 types.
"""
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
......@@ -109,6 +109,7 @@ class RMSNorm(CustomOp):
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
......@@ -117,8 +118,10 @@ class RMSNorm(CustomOp):
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
......
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange
@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
NUM_BLOCK, CBLOCK: tl.constexpr):
# This kernel computes the diagonal blocks of the attention matrix
# Each diagonal block represents attention
# where queries attend to keys in the same block
off = tl.program_id(0)
off_bh = off // NUM_BLOCK # batch-head index
off_block = off % NUM_BLOCK # block index within the sequence
off_cblock = tl.program_id(1) # sub-block index within a block
off_h = off_bh % h # head index
# Calculate base offsets for the current batch and head
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
# Calculate offsets for the current block
block_offset = off_block * BLOCK
qk_block_offset = block_offset * d
v_block_offset = block_offset * e
o_block_offset = block_offset * e
# Calculate offsets for the current sub-block
cblock_offset = off_cblock * CBLOCK
q_cblock_offset = cblock_offset * d
o_cblock_offset = cblock_offset * e
# Calculate pointers to the query, key, value, and output tensors
Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
tl.arange(0, CBLOCK)[:, None] * d +
tl.arange(0, d)[None, :])
K_trans_block_ptr = (K + qk_offset + qk_block_offset +
tl.arange(0, CBLOCK)[None, :] * d +
tl.arange(0, d)[:, None])
V_block_ptr = (V + v_offset + v_block_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, e)[None, :])
O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, e)[None, :])
# Load the decay rate for the current head
S_block_ptr = S + off_h
s = tl.load(S_block_ptr)
i = off_cblock
q_index = tl.arange(0, CBLOCK) + i * CBLOCK
# Load query values
q = tl.load(Q_block_ptr,
mask=block_offset + q_index[:, None] < n,
other=0.0).to(tl.float32)
# Initialize output accumulator
qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
# Process all sub-blocks up to and
# including the current one (causal attention)
for j in range(i + 1):
kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
diff = q_index[:, None] - kv_index[None, :]
s_index = s * diff
# Apply causal mask: only attend to positions before the current one
s_index = tl.where(diff >= 0, -s_index, float("-inf"))
decay = tl.exp(s_index)
# Load key and value
k_trans = tl.load(
K_trans_block_ptr,
mask=block_offset + kv_index[None, :] < n,
other=0.0,
).to(tl.float32)
v = tl.load(
V_block_ptr,
mask=block_offset + kv_index[:, None] < n,
other=0.0,
).to(tl.float32)
# Compute attention scores and apply decay
qk = tl.dot(q, k_trans) * decay
# Compute weighted values and accumulate
qkv += tl.dot(qk, v)
# Move to the next sub-block
K_trans_block_ptr += CBLOCK * d
V_block_ptr += CBLOCK * e
# Store the result
tl.store(
O_block_ptr,
qkv.to(O_block_ptr.dtype.element_ty),
mask=block_offset + q_index[:, None] < n,
)
@triton.jit
def _fwd_kv_parallel(
K,
V,
K_decay,
KV,
b: tl.constexpr,
h: tl.constexpr,
n,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK,
D_FBLOCK: tl.constexpr,
E_FBLOCK: tl.constexpr,
NUM_FBLOCK: tl.constexpr,
CBLOCK: tl.constexpr,
NUM_CBLOCK: tl.constexpr,
):
# This kernel computes the key-value outer
# products for each block in parallel
off_bh = tl.program_id(0) # batch-head index
off_block = tl.program_id(1) # block index
off_h = off_bh % h # head index
block_offset = off_block * BLOCK
# Calculate offsets for the current block
k_block_offset = block_offset * d
v_block_offset = block_offset * e
kv_block_offset = off_block * d * e
# Calculate base offsets for the current batch and head
k_offset = off_bh * n * d
v_offset = off_bh * n * e
kv_offset = off_bh * NUM_BLOCK * d * e
# Calculate pointers to the key, value, and key-value tensors
K_trans_block_ptr = (K + k_offset + k_block_offset +
tl.arange(0, CBLOCK)[None, :] * d +
tl.arange(0, D_FBLOCK)[:, None])
V_block_ptr = (V + v_offset + v_block_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
KV_block_ptr = (KV + kv_offset + kv_block_offset +
tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay factors for the current head and block
k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
kv_index = tl.arange(0, CBLOCK)
# Initialize the key-value outer product accumulator
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
# Handle the last block which might be smaller than BLOCK
if off_block == NUM_BLOCK - 1:
split_n = n - (NUM_BLOCK - 1) * BLOCK
else:
split_n = BLOCK
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
# Process all sub-blocks in the current block
for j in range(num_blocks):
left_bound = (1 - j) * left_shift
# Load key and value, handling boundary conditions
k_trans = tl.load(K_trans_block_ptr - left_shift * d,
mask=kv_index[None, :] >= left_bound,
other=0.0)
v = tl.load(V_block_ptr - left_shift * e,
mask=kv_index[:, None] >= left_bound,
other=0.0)
# Load decay factor and compute weighted key-value outer product
k_decay = tl.load(k_decay_ptr)
kv += tl.dot(k_trans * k_decay, v)
# Move to the next sub-block
K_trans_block_ptr += CBLOCK * d
V_block_ptr += CBLOCK * e
k_decay_ptr += CBLOCK
# Store the result
tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
@triton.jit
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
# This kernel reduces the key-value outer products
# across blocks and updates the KV history
off_bh = tl.program_id(0) # batch-head index
off_h = off_bh % h # head index
kv_offset = off_bh * NUM_BLOCK * d * e
# Calculate pointer to the key-value tensor
KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay rate for the current head
s_ptrs = S + off_h
s = tl.load(s_ptrs)
# Calculate pointer to the key-value history tensor
kv_history_offset = off_bh * d * e
KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the previous key-value history
kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
# Process all blocks in reverse order to compute the prefix sum
for i in range(NUM_BLOCK):
block_size = min(n - i * BLOCK, BLOCK)
# Compute decay factor for the current block
block_decay = tl.exp(-s.to(tl.float32) * block_size)
# Load the current key-value outer product
kv_cur = tl.load(KV_block_ptr).to(tl.float32)
# Store the previous key-value history to the current block
tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
# Update the key-value history with the current block
kv_pre = block_decay * kv_pre + kv_cur
KV_block_ptr += d * e
# Store the updated key-value history
tl.store(KV_HISTORY_block_ptr, kv_pre)
@triton.jit
def _fwd_none_diag_kernel(
Q,
Out,
S,
KV,
b: tl.constexpr,
h: tl.constexpr,
n,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK,
E_FBLOCK: tl.constexpr,
CBLOCK: tl.constexpr,
NUM_CBLOCK: tl.constexpr,
):
# This kernel computes the non-diagonal blocks of the attention matrix
# Each non-diagonal block represents attention
# where queries attend to keys in different blocks
off_bh = tl.program_id(0) # batch-head index
off_h = off_bh % h # head index
off_nc = tl.program_id(1)
off_n = off_nc // NUM_CBLOCK # block index
off_c = off_nc % NUM_CBLOCK # sub-block index
off_e = tl.program_id(2) # output feature block index
n_offset = off_n * BLOCK
c_offset = off_c * CBLOCK
e_offset = off_e * E_FBLOCK
block_offset = n_offset + c_offset
# Calculate offsets for the current batch, head, and block
q_offset = off_bh * n * d + (n_offset + c_offset) * d
o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
# Calculate pointers to the query, output, and key-value tensors
Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
tl.arange(0, d)[None, :])
O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay rate for the current head
S_block_ptr = S + off_h
s = tl.load(S_block_ptr)
c_array = tl.arange(0, CBLOCK)
# Load the key-value outer product for the current block
kv = tl.load(KV_block_ptr).to(tl.float32)
q_index = block_offset + tl.arange(0, CBLOCK)
# Load query values
q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
other=0.).to(tl.float32)
# Compute decay factors for the current sub-block
q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
# Compute non-diagonal attention output
qkv_none_diag = tl.dot(q, kv) * q_decay
# Load diagonal attention output (computed by _fwd_diag_kernel)
qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
other=0.).to(tl.float32)
# Combine diagonal and non-diagonal attention outputs
qkv = qkv_diag + qkv_none_diag
# Store the result
tl.store(O_block_ptr,
qkv.to(O_block_ptr.dtype.element_ty),
mask=q_index[:, None] < n)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, s, kv_history):
# Forward pass of the lightning attention algorithm
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
s = s.contiguous()
# Check CUDA compute capability
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported",
"for compute capability >= 80")
# Get input dimensions
b, h, n, d = q.shape
e = v.shape[-1]
# Initialize output tensor
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
# Set block sizes
BLOCK = 256
NUM_BLOCK = triton.cdiv(n, BLOCK)
CBLOCK = 32
NUM_CBLOCK = BLOCK // CBLOCK
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
# Compute decay factors for keys
array = torch.arange(0, BLOCK, device=q.device) + 1
k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))
# Step 1: Compute diagonal blocks of attention
grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
_fwd_diag_kernel[grid](q,
k,
v,
o,
s,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
CBLOCK=CBLOCK)
# Set feature block sizes
NUM_FBLOCK = 1
D_FBLOCK = d // NUM_FBLOCK
assert d % NUM_FBLOCK == 0
E_FBLOCK = e // NUM_FBLOCK
assert e % NUM_FBLOCK == 0
CBLOCK = 64
NUM_CBLOCK = BLOCK // CBLOCK
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
# Step 2: Compute key-value outer products for each block in parallel
kv = torch.empty((b, h, NUM_BLOCK, d, e),
dtype=torch.float32,
device=q.device)
grid = (b * h, NUM_BLOCK)
_fwd_kv_parallel[grid](
k,
v,
k_decay,
kv,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
D_FBLOCK=D_FBLOCK,
E_FBLOCK=E_FBLOCK,
NUM_FBLOCK=NUM_FBLOCK,
CBLOCK=CBLOCK,
NUM_CBLOCK=NUM_CBLOCK,
)
# Step 3: Reduce key-value outer products
# across blocks and update KV history
grid = (b * h, NUM_FBLOCK)
_fwd_kv_reduce[grid](s,
kv,
kv_history,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
D_FBLOCK=D_FBLOCK,
E_FBLOCK=E_FBLOCK)
# Step 4: Compute non-diagonal blocks of attention
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
_fwd_none_diag_kernel[grid](
q,
o,
s,
kv,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
E_FBLOCK=E_FBLOCK,
CBLOCK=CBLOCK,
NUM_CBLOCK=NUM_CBLOCK,
)
# Save tensors for backward pass
ctx.save_for_backward(q, k, v, s, kv)
ctx.BLOCK = BLOCK
return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)
# Apply the lightning attention function
lightning_attention_ = _attention.apply
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
"""
Apply lightning attention algorithm
to compute attention efficiently.
Args:
q: Query tensor of shape [batch, heads, seq_len, dim]
k: Key tensor of shape [batch, heads, seq_len, dim]
v: Value tensor of shape [batch, heads, seq_len, dim_v]
ed: Decay rate tensor of shape [heads]
block_size: Size of blocks for block-sparse attention
kv_history: Optional key-value history from previous computations
Returns:
output: Attention output
kv: Updated key-value history
"""
d = q.shape[-1]
e = v.shape[-1]
if ed.dim() == 1:
ed = ed.view(1, -1, 1, 1)
# Split the computation into chunks for better parallelism
m = 128 if d >= 128 else 64
assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
arr = [m * i for i in range(d // m + 1)]
if arr[-1] != d:
arr.append(d)
n = len(arr)
output = 0
# Initialize or clone key-value history
if kv_history is None:
kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
dtype=torch.float32,
device=q.device)
else:
kv_history = kv_history.clone().contiguous()
# Process each chunk and accumulate results
for i in range(n - 1):
s = arr[i]
e = arr[i + 1]
q1 = q[..., s:e]
k1 = k[..., s:e]
o, kv = lightning_attention_(q1, k1, v, ed, kv_history)
output = output + o
return output, kv
@triton.jit
def _linear_attn_decode_kernel(
q_ptr,
k_ptr,
v_ptr,
kv_cache_ptr,
slope_rate,
slot_idx,
output_ptr,
D: tl.constexpr,
qkv_b_stride,
qkv_h_stride,
cache_b_stride,
cache_h_stride,
cache_d0_stride,
cache_d1_stride,
BLOCK_SIZE: tl.constexpr,
):
"""
Kernel for linear attention decoding with KV cache.
This kernel computes attention for a single token using the KV cache.
"""
pid_b = tl.program_id(0) # batch index
pid_h = tl.program_id(1) # head index
pid_d = tl.program_id(2) # dimension block index
# Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b)
# Skip if slot_id is -1 (padding)
if slot_id == -1:
return
batch_id = pid_b
head_id = pid_h
# Load decay rate for the current head
ratio = tl.load(slope_rate + pid_h)
# Calculate offsets for dimensions
qk_d_offsets = tl.arange(0, D)
v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
None, :] * cache_d1_stride
# Calculate offsets for the current batch and head
q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride
# Create masks for loading tensors
qk_mask = qk_d_offsets < D
v_mask = v_d_offsets < D
# Load query, key, and value tensors
q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)
# Compute key-value outer product
kv_outer = k[:, None] * v[None, :]
kv_mask = qk_mask[:, None] & v_mask[None, :]
# Apply decay to previous KV cache
ratio = tl.exp(-ratio)
kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
kv_outer = kv_outer + ratio * kv_cache_old
# Compute attention output
output = q[:, None].to(tl.float32) * kv_outer
output = tl.sum(output, axis=0)
# Update KV cache and store output
tl.store(kv_ptr, kv_outer, mask=kv_mask)
tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
def linear_decode_forward_triton(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
slot_idx: torch.Tensor,
BLOCK_SIZE: int = 32,
) -> torch.Tensor:
"""
Perform linear attention decoding using Triton kernels.
Args:
q: Query tensor of shape [B, H, 1, D]
k: Key tensor of shape [B, H, 1, D]
v: Value tensor of shape [B, H, 1, D]
kv_caches: Key-value cache tensor
slope_rate: Decay rate tensor
slot_idx: Slot indices for batches
BLOCK_SIZE: Size of blocks for processing
Returns:
output: Attention output tensor
"""
B, H, _, D = q.shape
assert k.shape == (B, H, 1, D)
assert v.shape == (B, H, 1, D)
# Initialize output tensor
output = torch.empty_like(q)
# Set grid dimensions for the kernel
grid = (B, H, D // BLOCK_SIZE)
# Calculate strides for tensors
qkv_b_stride = q.stride(0)
qkv_h_stride = q.stride(1)
cache_b_stride = kv_caches.stride(0)
cache_h_stride = kv_caches.stride(1)
cache_d0_stride = kv_caches.stride(2)
cache_d1_stride = kv_caches.stride(3)
# Launch the kernel
_linear_attn_decode_kernel[grid](
q,
k,
v,
kv_caches,
slope_rate,
slot_idx,
output,
D,
qkv_b_stride,
qkv_h_stride,
cache_b_stride,
cache_h_stride,
cache_d0_stride,
cache_d1_stride,
BLOCK_SIZE=BLOCK_SIZE,
)
# Reshape output and return
output = rearrange(output, "b h n d -> b n (h d)")
return output.squeeze(1).contiguous()
......@@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
......@@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......
......@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig):
......@@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
......@@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out_dim_1,
dtype=torch.bfloat16,
device=x.device)
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
apply_bnb_4bit(bf_x, qweight, offsets, out)
out = out.to(original_type)
if reshape_after_matmul:
......@@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out += bias
return out
def _apply_bnb_4bit(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
quant_states = weight.bnb_quant_state
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
def _apply_bnb_4bit_fake(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
return
try:
direct_register_custom_op(
op_name="apply_bnb_4bit",
op_func=_apply_bnb_4bit,
mutates_args=["out"],
fake_impl=_apply_bnb_4bit_fake,
)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
except AttributeError as error:
raise error
......@@ -97,7 +97,8 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self)
return CompressedTensorsMoEMethod.get_moe_method(
self, layer.activation, layer.expert_map)
return None
@classmethod
......@@ -192,17 +193,26 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
error: bool = True,
match_exact: bool = False) -> bool:
capability_tuple = current_platform.get_device_capability()
if capability_tuple is not None:
capability = capability_tuple.to_int()
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
if match_exact:
supported = capability == min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
"the current GPU. Required capability: ",
f"{min_capability}. Current capability: {capability}.")
else:
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported
else:
return False
......@@ -263,6 +273,11 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights quantized.
......
......@@ -31,6 +31,7 @@ class GPTQMarlinState(Enum):
__all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsWNA16MoEMethod"
]
......@@ -39,7 +40,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
activation: str,
expert_map: Optional[torch.Tensor],
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -49,6 +52,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
return CompressedTensorsWNA16MoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and activation == "silu" and expert_map is None):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
else:
......@@ -218,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -234,20 +241,245 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
== QuantizationStrategy.TENSOR)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel):
raise ValueError(
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
device = w13_weight.device
# TODO strides can be shared across multiple layers
self.ab_strides1 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full((num_experts, ),
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full((num_experts, ),
intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides2 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer.")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu"
assert global_num_experts == layer.w13_weight.shape[0]
assert expert_map is None
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
return cutlass_moe_fp8(
x,
layer.w13_weight.transpose(1, 2),
layer.w2_weight.transpose(1, 2),
layer.w13_weight_scale,
layer.w2_weight_scale,
topk_weights,
topk_ids,
self.ab_strides1,
self.c_strides1,
self.ab_strides2,
self.c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
......@@ -551,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
......@@ -558,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......
......@@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
......@@ -143,5 +144,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
......@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):
......
......@@ -73,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.out_dtype = torch.get_default_dtype()
def create_weights(
self,
......@@ -161,6 +162,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias)
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Any, Callable, Dict, List, Optional
import torch
......@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
......@@ -116,6 +125,21 @@ class Fp8Config(QuantizationConfig):
return Fp8KVCacheMethod(self)
return None
def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
return None
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
......@@ -138,6 +162,7 @@ class Fp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
......@@ -386,6 +411,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
......@@ -407,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
......@@ -529,6 +568,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
is_rocm_aiter_moe_enabled, shuffle_weights)
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
......@@ -554,6 +598,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
requires_grad=False)
if is_rocm_aiter_block_scaled_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
return
# If checkpoint is fp16, quantize in place.
......@@ -581,6 +647,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
w13_scales, w2_scales = expand_weights(
layer.w13_weight_scale.data,
layer.w2_weight_scale.data,
expansion_dims=[
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
])
layer.w13_weight_scale = torch.nn.Parameter(
w13_scales.contiguous(), requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight, layer.w2_weight)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
......@@ -648,6 +734,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dq_weight, max_w13_scales[expert_id])
start += shard_size
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
expansion_dims = [
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
]
max_w13_scales, w2_scales = expand_weights(
max_w13_scales,
layer.w2_weight_scale.data,
expansion_dims=expansion_dims)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight, layer.w2_weight)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
......@@ -667,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -694,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
......@@ -702,6 +810,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
......
......@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
......@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused GGUF MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -377,7 +383,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0]).to(self.params_dtype)
x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment