Commit 65f43084 authored by zhuwenwen
Browse files

Merge branch 'v0.15.1-dev-wm' into 'v0.15.1-dev'

适配w8a8 deepep,接入lightop版deepgemm

See merge request dcutoolkit/deeplearing/vllm!418
parents e807ec39 d08e3d52
......@@ -259,7 +259,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20
self.num_sms = 30
def get_handle(self, kwargs):
raise NotImplementedError
......@@ -292,16 +292,21 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
#num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_nvl_bytes = int(2e9/2)#1024 * 1024 * 1024
num_rdma_bytes = None
num_qps_per_rank = None
if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = self.num_sms // 2
# num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
# num_qps_per_rank = self.num_sms // 2
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2
self.num_sms = 30
else:
num_rdma_bytes = 0
num_qps_per_rank = 1
self.num_sms = 60
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
......
......@@ -162,6 +162,8 @@ def maybe_make_prepare_finalize(
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
)
use_int8_dispatch = quant_config.quant_dtype == torch.int8
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
......@@ -170,6 +172,7 @@ def maybe_make_prepare_finalize(
global_to_physical=global_to_physical,
physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids,
use_int8_dispatch=use_int8_dispatch,
)
elif moe.use_mori_kernels:
assert quant_config is not None
......
......@@ -32,6 +32,16 @@ from vllm.utils.deep_gemm import (
)
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant_ep
if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked
else:
from lightop import m_grouped_w8a8_gemm_nt_masked
logger = init_logger(__name__)
......@@ -267,6 +277,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
N: int = -1,
K: int = -1,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
......@@ -279,8 +291,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
if quant_config.use_fp8_w8a8:
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
self.N = N
self.K = K
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
......@@ -398,6 +414,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
......@@ -408,12 +425,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q = hidden_states
_, N, K = w1.size()
assert w2.size(1) == K
#assert w2.size(1) == K
E, max_num_tokens, N, K, _ = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)
if self.N > 0:
N = self.N
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
expected_m = self.estimate_expected_m(
......@@ -422,6 +442,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk=topk_ids.size(-1),
)
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked(
(a1q, a1q_scale),
(w1, self.w1_scale),
......@@ -444,3 +465,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens,
expected_m,
)
elif self.quant_config.use_int8_w8a8:
m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m)
else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
......@@ -87,7 +87,7 @@ def _quant_flags_to_group_shape(
"""
a_shape: GroupShape | None
w_shape: GroupShape | None
if block_shape is not None:
if block_shape is not None and quant_dtype!=torch.int8:
assert not per_act_token_quant
assert not per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first
......@@ -211,10 +211,10 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
)
# def __post_init__(self):
# assert not self.per_act_token_quant or self.block_shape is None, (
# "illegal quantization"
# )
#
# Convenience accessors for various properties.
......@@ -246,6 +246,9 @@ class FusedMoEQuantConfig:
@property
def block_shape(self) -> list[int] | None:
if self.use_int8_w8a8:
return [256, 256]
if (
self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR
......@@ -569,7 +572,7 @@ def int8_w8a8_moe_quant_config(
a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=None,
block_shape=[256, 256],
)
......
......@@ -37,6 +37,12 @@ from vllm.utils.deep_gemm import (
)
from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant
if has_deep_gemm():
from deep_gemm import m_grouped_i8_gemm_nt_contiguous
else:
from lightop import m_grouped_w8a8_gemm_nt_contig_asm as m_grouped_i8_gemm_nt_contiguous
logger = init_logger(__name__)
......@@ -113,13 +119,23 @@ def _valid_deep_gemm(
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
def __init__(self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
N: int = -1,
K: int = -1,):
super().__init__(moe_config=moe_config, quant_config=quant_config)
if quant_config.use_fp8_w8a8 or quant_config.use_fp8_w8a16:
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
self.N = N
self.K = K
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
......@@ -241,6 +257,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
):
assert a1q_scale is not None
assert a2_scale is None
......@@ -255,19 +272,24 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if global_num_experts == -1:
global_num_experts = local_num_experts
assert w2.size(1) == K
#assert w2.size(1) == K
if self.N > 0:
N = self.N
K = self.K
use_fp8 = self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=get_mk_alignment_for_contiguous_layout()[0],
alignment=get_mk_alignment_for_contiguous_layout()[0] if use_fp8 else self.block_shape[0],
expert_tokens_meta=expert_tokens_meta,
)
a1q_perm = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
workspace13.view(dtype=torch.float8_e4m3fn if use_fp8 else a1q.dtype), (M_sum, K)
)
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
......@@ -280,6 +302,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert a1q.size(0) == M_sum
mm1_out = _resize_cache(workspace2, (M_sum, N))
if use_fp8:
m_grouped_fp8_gemm_nt_contiguous(
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
)
......@@ -296,6 +320,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
m_grouped_fp8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
)
elif self.quant_config.use_int8_w8a8:
m_grouped_i8_gemm_nt_contiguous(
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids)
a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_i8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids)
else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
......
......@@ -13,6 +13,8 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
from vllm.utils.math_utils import round_up
from lightop import op
def expert_num_tokens_round_up_and_sum(
expert_num_tokens: torch.Tensor, alignment: int
......@@ -57,6 +59,12 @@ def round_up_128(x: int) -> int:
return ((x + y - 1) // y) * y
@triton.jit
def round_up_256(x: int) -> int:
    # Round x up to the nearest multiple of 256 using integer arithmetic:
    # add (alignment - 1), then truncate down to a multiple of the alignment.
    alignment = 256
    bumped = x + alignment - 1
    return (bumped // alignment) * alignment
@triton.jit
def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert,
......@@ -74,26 +82,27 @@ def _fwd_kernel_ep_scatter_1(
mask=offset_cumsum < num_experts,
other=0,
)
tokens_per_expert = round_up_128(tokens_per_expert)
#tokens_per_expert = round_up_128(tokens_per_expert)
tokens_per_expert = round_up_256(tokens_per_expert)
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
#if cur_expert == 0:
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
tl.debug_barrier()
#cur_expert_start = cumsum[cur_expert]
cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
offs = start_m + off_expert
mask = offs < cur_expert_token_num
tl.store(
m_indices_start_ptr + offs,
m_indices_start_ptr + start_m + off_expert,
cur_expert,
mask=mask,
mask=start_m + off_expert < cur_expert_token_num
)
......@@ -133,26 +142,32 @@ def _fwd_kernel_ep_scatter_2(
offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = index_in_s < SCALE_HIDDEN_SIZE
for token_id in range(start_token_id, total_token_num, grid_num):
for token_id_int32 in range(start_token_id, total_token_num, grid_num):
token_id = token_id_int32.to(tl.int64)
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load(
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
recv_x_scale
+ token_id * recv_x_scale_stride0
+ index_in_s * recv_x_scale_stride1,
mask=mask_s,
)
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
topk_index = topk_idx_int32.to(tl.int64)
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index = dest_token_index_int32.to(tl.int64)
tl.store(
output_index + token_id * output_index_stride0 + topk_index,
dest_token_index,
dest_token_index_int32,
)
output_tensor_ptr = (
output_tensor + dest_token_index * output_tensor_stride0
......@@ -161,7 +176,11 @@ def _fwd_kernel_ep_scatter_2(
output_tensor_scale + dest_token_index * output_tensor_scale_stride0
)
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
tl.store(
output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
to_copy_s,
mask=mask_s,
)
@torch.no_grad()
......@@ -177,16 +196,27 @@ def ep_scatter(
m_indices: torch.Tensor,
output_index: torch.Tensor,
):
BLOCK_E = 128 # token num of per expert is aligned to 128
BLOCK_D = 128 # block size of quantization
# BLOCK_E = 128 # token num of per expert is aligned to 128
# BLOCK_D = 128 # block size of quantization
BLOCK_E = 256 # token num of per expert is aligned to 256
num_warps = 8
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
......@@ -226,8 +256,10 @@ def ep_scatter(
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
# SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
# SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return
......@@ -255,25 +287,34 @@ def _fwd_kernel_ep_gather(
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_D: tl.constexpr,
):
cur_block = tl.program_id(0)
start_cur_token = tl.program_id(1)
cur_block_int32 = tl.program_id(0)
cur_block = cur_block_int32.to(tl.int64)
start_cur_token_int32 = tl.program_id(1)
grid_num = tl.num_programs(1)
for cur_token in range(start_cur_token, total_token_num, grid_num):
for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
cur_token = cur_token_int32.to(tl.int64)
off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index in range(0, topk_num):
for topk_index_int32 in range(0, topk_num):
topk_index = topk_index_int32.to(tl.int64)
expert_id = tl.load(
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
)
if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0:
source_token_index = tl.load(
source_token_index_int32 = tl.load(
input_index + cur_token * input_index_stride0 + topk_index
)
source_token_index = source_token_index_int32.to(tl.int64)
acc_weight = tl.load(
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
)
......@@ -350,7 +391,8 @@ def deepgemm_moe_permute(
H = aq.size(1)
device = aq.device
block_m, block_k = get_mk_alignment_for_contiguous_layout()
#block_m, block_k = get_mk_alignment_for_contiguous_layout()
block_m = 256
M_sum = compute_aligned_M(
M=topk_ids.size(0),
......@@ -368,8 +410,11 @@ def deepgemm_moe_permute(
if aq_out is None:
aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype)
# aq_scale_out = torch.empty(
# (M_sum, H // block_k), device=device, dtype=torch.float32
# )
aq_scale_out = torch.empty(
(M_sum, H // block_k), device=device, dtype=torch.float32
(M_sum, aq_scale.shape[-1]), device=device, dtype=torch.float32
)
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark
......
......@@ -225,7 +225,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
if not quant_config.is_block_quantized and not quant_config.is_per_act_token:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
......@@ -266,7 +266,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.is_block_quantized:
if quant_config.is_block_quantized or quant_config.is_per_act_token:
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Callable, Optional
import deep_ep
import torch
......@@ -91,12 +92,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
global_to_physical: torch.Tensor | None = None,
physical_to_global: torch.Tensor | None = None,
local_expert_global_ids: torch.Tensor | None = None,
use_int8_dispatch: bool = False
):
super().__init__()
self.buffer = buffer
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
self.use_int8_dispatch = use_int8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
......@@ -168,6 +171,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
expert_num_tokens: Optional[torch.Tensor]= None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.use_fp8_dispatch:
block_k = (
......@@ -183,6 +187,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dequant to get back the tokens in the datatype we dispatched in.
x_fp8, x_scales = x
x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
elif self.use_int8_dispatch:
x, x_scales = x
return x, x_scales
assert isinstance(x, (torch.Tensor, tuple))
q_dtype = quant_config.quant_dtype
......@@ -214,6 +221,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant
if expert_num_tokens is None:
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(
x,
......@@ -294,7 +302,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
......@@ -327,7 +336,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config, expert_num_tokens)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
......
......@@ -54,6 +54,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
moe_parallel_config=moe_layer.moe_parallel_config,
N=old_quant_method.N if hasattr(old_quant_method, "N") else -1,
K=old_quant_method.K if hasattr(old_quant_method, "K") else -1,
),
)
......@@ -95,6 +97,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.fused_experts(
hidden_states=x,
......
......@@ -785,6 +785,21 @@ def _slice_scales(
return None
_alt_stream: torch.cuda.Stream | None = None
def alt_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _alt_stream
# TODO: validate this works properly on ROCm platform.
if _alt_stream is None:
_alt_stream = torch.cuda.Stream()
return _alt_stream
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
......@@ -805,6 +820,8 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
N: int = -1,
K: int = -1,
):
super().__init__()
self.prepare_finalize = prepare_finalize
......@@ -831,6 +848,12 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_format()}"
)
self.N = N
self.K = K
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
def _post_init_setup(self):
"""
......@@ -1136,9 +1159,9 @@ class FusedMoEModularKernel(torch.nn.Module):
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
)
if use_nn_moe:
N = w1.size(2)
if self.N > 0:
N = self.N
K = self.K
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
......@@ -1244,6 +1267,14 @@ class FusedMoEModularKernel(torch.nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
else:
self.alt_event.record()
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.alt_stream):
self.alt_stream.wait_event(self.alt_event)
finalize_ret = self.prepare_finalize.finalize_async(
output,
fused_out,
......@@ -1252,8 +1283,6 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
......@@ -1276,6 +1305,9 @@ class FusedMoEModularKernel(torch.nn.Module):
receiver()
self.alt_event.record()
current_stream.wait_event(self.alt_event)
if self.shared_experts is None:
return output
else:
......
......@@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from math import prod
from typing import Optional
import torch
import torch.nn.functional as F
from triton.language.extra import libdevice
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
......@@ -154,11 +156,147 @@ def _fp8_quantize(
return A, A_scale
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    T_dim,
    has_tokens_per_expert: tl.constexpr,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    # Per-row symmetric int8 quantization: each program handles exactly the
    # row selected by program_id(0) and writes one scale for that row.
    row_id = tl.program_id(0)

    if has_tokens_per_expert:
        # Rows are addressed as expert * T_dim + slot. Slots at or beyond the
        # expert's token count are skipped, so nothing is stored for them.
        expert = row_id // T_dim
        slot = row_id % T_dim
        valid_count = tl.load(tokens_per_expert_ptr + expert)
        if slot >= valid_count:
            return

    col_offsets = tl.arange(0, BLOCK)
    in_bounds = col_offsets < N

    row_vals = tl.load(
        x_ptr + row_id * stride_x + col_offsets, mask=in_bounds, other=0.0
    ).to(tl.float32)

    # Clamp the row maximum away from zero so the division below is defined.
    row_absmax = tl.maximum(tl.max(tl.abs(row_vals)), 1e-10)
    row_scale = row_absmax / 127
    quantized = row_vals * (127 / row_absmax)
    quantized = libdevice.nearbyint(quantized).to(tl.int8)

    tl.store(xq_ptr + row_id * stride_xq + col_offsets, quantized, mask=in_bounds)
    tl.store(scale_ptr + row_id, row_scale)
@triton.jit
def _per_token_quant_int8_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    E_dim,
    T_dim,
    has_tokens_per_expert: tl.constexpr,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    # Per-row symmetric int8 quantization where each program walks several
    # rows: it starts at its own program id and advances by the launch size
    # until all E_dim * T_dim rows have been visited.
    first_row = tl.program_id(0)
    row_step = tl.num_programs(0)
    total_rows = E_dim * T_dim

    for row in range(first_row, total_rows, row_step):
        process_row = True

        if has_tokens_per_expert:
            # Rows are addressed as expert * T_dim + slot; slots at or beyond
            # the expert's token count are left unwritten.
            expert = row // T_dim
            slot = row % T_dim
            valid_count = tl.load(tokens_per_expert_ptr + expert)
            if slot >= valid_count:
                process_row = False

        if process_row:
            col_offsets = tl.arange(0, BLOCK)
            in_bounds = col_offsets < N

            row_vals = tl.load(
                x_ptr + row * stride_x + col_offsets, mask=in_bounds, other=0.0
            ).to(tl.float32)

            # Clamp the row maximum away from zero so the division is defined.
            row_absmax = tl.maximum(tl.max(tl.abs(row_vals)), 1e-10)
            row_scale = row_absmax / 127
            quantized = row_vals * (127 / row_absmax)
            quantized = libdevice.nearbyint(quantized).to(tl.int8)

            tl.store(xq_ptr + row * stride_xq + col_offsets, quantized, mask=in_bounds)
            tl.store(scale_ptr + row, row_scale)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
                                    tokens_per_expert: Optional[torch.Tensor] = None):
    """
    Quantize a 3D activation tensor to int8 with one scale per row.

    Args:
        x: Input of shape [E, T, H].
        tokens_per_expert: Optional per-expert token counts. When given, the
            kernels skip rows whose index within an expert is at or beyond
            that expert's count; the matching output entries are never
            written and keep whatever ``torch.empty`` left there.

    Returns:
        ``(x_q, scales)``: ``x_q`` is int8 with the shape of ``x`` and
        ``scales`` is float32 with shape ``x.shape[:-1] + (1,)``.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
    """
    if x.dim() != 3:
        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")

    num_experts, rows_per_expert, hidden = x.shape

    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1, ),
                         device=x.device,
                         dtype=torch.float32)

    BLOCK = triton.next_power_of_2(hidden)
    # Very long expert slabs run with a single warp; otherwise the warp count
    # scales with the row width, clamped to [1, 8].
    num_warps = 1 if rows_per_expert >= 4096 else min(max(BLOCK // 256, 1), 8)

    total_rows = num_experts * rows_per_expert

    # Launch arguments shared by both kernels.
    common = dict(
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=hidden,
        T_dim=rows_per_expert,
        has_tokens_per_expert=tokens_per_expert is not None,
        tokens_per_expert_ptr=tokens_per_expert,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )

    if num_experts == 16 and rows_per_expert >= 1024:
        # Shrink the grid and let each program loop over several rows.
        looped_grid = max(1, total_rows // (rows_per_expert // 256))
        _per_token_quant_int8_kernel_opt[(looped_grid, )](
            x, x_q, scales, E_dim=num_experts, **common)
    else:
        # One program per row.
        _per_token_quant_int8_one_kernel_opt[(total_rows, )](
            x, x_q, scales, **common)

    return x_q, scales
def _int8_quantize(
A: torch.Tensor,
A_scale: torch.Tensor | None,
per_act_token: bool,
block_shape: list[int] | None = None,
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform int8 quantization on the inputs. If a block_shape
......@@ -168,9 +306,12 @@ def _int8_quantize(
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None:
if block_shape is None or per_act_token:
assert per_act_token, "int8 quantization only supports block or channel-wise"
if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A)
else:
A, A_scale = per_token_quant_int8_triton_opt(A, expert_num_tokens)
else:
assert not per_act_token
assert len(block_shape) == 2
......
......@@ -6,15 +6,26 @@ from enum import Enum
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import (QuantizationStrategy)
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from torch.nn.parameter import Parameter
from vllm.distributed import get_ep_group, get_dp_group
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoeWeightScaleSupported, FusedMoEConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig)
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported,
)
try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
......@@ -75,14 +86,38 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return int8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.int8
# WEIGHTS
......@@ -133,14 +168,20 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
......@@ -177,3 +218,39 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
)
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    """
    Build the expert GEMM implementation matching ``prepare_finalize``.

    A batched-experts activation format selects ``BatchedDeepGemmExperts``;
    any other format selects ``DeepGemmExperts``. Both receive this method's
    quant config and the N / K problem sizes recorded on ``self``.
    """
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
        DeepGemmExperts,
    )

    # N / K are only assigned when the DeepEP path is enabled. Fall back to
    # -1, the "unset" default both expert constructors accept, instead of
    # raising AttributeError here.
    n_size = getattr(self, "N", -1)
    k_size = getattr(self, "K", -1)

    if (
        prepare_finalize.activation_format
        == FusedMoEActivationFormat.BatchedExperts
    ):
        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens_per_rank is not None
        logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
        return BatchedDeepGemmExperts(
            moe_config=self.moe,
            max_num_tokens=max_num_tokens_per_rank,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            quant_config=self.moe_quant_config,
            N=n_size,
            K=k_size,
        )

    logger.debug("DeepGemmExperts(%s)", self.__class__.__name__)
    return DeepGemmExperts(
        moe_config=self.moe,
        quant_config=self.moe_quant_config,
        N=n_size,
        K=k_size,
    )
\ No newline at end of file
......@@ -30,6 +30,19 @@ def get_w8a8_int8_marlin_weights(
return weight
def w8a8_nt_kpack2_marlin_weight(w8a8_w,  # [size_n, size_k// 2 ]
                                 k_tile=16,
                                 n_tile=16, ):
    """
    Re-tile a 2D int8 weight into contiguous (n_tile x k_tile) blocks.

    The matrix is split into tiles of ``n_tile`` rows by ``k_tile`` columns.
    Tiles are laid out row-major by (row-tile, column-tile) and each tile's
    elements are stored contiguously. The result has one output row per
    row-tile.

    Args:
        w8a8_w: 2D int8 tensor of shape ``(size_n, size_k)``.
        k_tile: Tile extent along the second (k) dimension.
        n_tile: Tile extent along the first (n) dimension.

    Returns:
        A contiguous int8 tensor of shape
        ``(size_n // n_tile, size_k * n_tile)``.

    Raises:
        AssertionError: If the dtype is not int8 or a dimension is not
            divisible by its tile size.
    """
    assert w8a8_w.dtype == torch.int8, "w8a8_w 必须是 int8 类型"
    size_n, size_k = w8a8_w.shape
    # Each dimension must be divisible by the tile size that is actually used
    # to split it below (n with n_tile, k with k_tile).
    assert size_n % n_tile == 0 and size_k % k_tile == 0, "k_tile / n_tile 必须能整除对应维度"
    tiled = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
    # Bring the two tile indices to the front so each tile becomes contiguous.
    tiled = tiled.permute((0, 2, 1, 3)).contiguous()
    # One output row per row-tile; identical to the previous shape whenever
    # k_tile == n_tile (the defaults).
    return tiled.reshape((size_n // n_tile, size_k * n_tile))
def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda():
return False
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment