Unverified Commit 53fa4573 authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[Misc] Add unit tests for MoE ModularKernel combinations + Profiling utility (#20449)


Signed-off-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
parent 6fb16244
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from .common import Config
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
def make_config_arg_parser(description: str):
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
if pf.__name__ == s:
return pf
raise ValueError(
f"Cannot find a PrepareFinalize type that matches {s}")
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
for fe in MK_FUSED_EXPERT_TYPES:
if fe.__name__ == s:
return fe
raise ValueError(f"Cannot find a FusedExperts type that matches {s}")
def to_quant_torch_dtype(s: str) -> torch.dtype:
if s == "torch.float8_e4m3fn":
return torch.float8_e4m3fn
raise ValueError(f"Unsupported quant type {s}")
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--world-size",
type=int,
default=2,
help="Number of ranks that participate in all2all",
)
parser.add_argument(
"--pf-type",
type=to_pf_class_type,
required=True,
help=("Choose a PrepareFinalize Type : "
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
)
parser.add_argument(
"--experts-type",
type=to_experts_class_type,
required=True,
help=(f"Choose a FusedExpert type : "
f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
)
parser.add_argument(
"-m",
nargs="+",
type=int,
default=[64],
help="num tokens per rank",
)
parser.add_argument(
"-k",
type=int,
default=7168,
help="hidden-size",
)
parser.add_argument(
"-n",
type=int,
default=1024,
help="N dimension of the first fused-moe matmul",
)
parser.add_argument("--num-experts",
type=int,
default=32,
help="Global num experts")
parser.add_argument("--topk",
nargs="+",
type=int,
default=[4, 1],
help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
nargs="+",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl."
)
# Quant args
parser.add_argument("--quant-dtype",
type=to_quant_torch_dtype,
help="Quant datatype")
parser.add_argument("--per-token-quantized-activations",
action='store_true',
help=("The input activations must be per-token "
"quantized"))
parser.add_argument("--per-channel-quantized-weights",
action="store_true",
help="The weights must be per-channel quantized.")
parser.add_argument("--block-shape",
nargs="+",
type=int,
help="Quantization block shape")
# Torch trace profile generation args
parser.add_argument("--torch-trace-dir-path",
type=str,
default=None,
help="Get torch trace for single execution")
return parser
def _validate_args(args: argparse.Namespace):
if args.quant_dtype is not None:
assert args.quant_dtype == torch.float8_e4m3fn
if args.block_shape is not None:
assert len(args.block_shape) == 2, (
f"block shape must have 2 elements. got {args.block_shape}")
if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
assert args.world_size == 1, (
"Single GPU objects need world size set to 1")
if args.torch_trace_dir_path is not None:
from pathlib import Path
assert Path(args.torch_trace_dir_path).is_dir(), (
f"Please create {args.torch_trace_dir_path}")
def make_config(args: argparse.Namespace) -> Config:
_validate_args(args)
quant_config = None
if args.quant_dtype is not None:
quant_config = FusedMoEQuantConfig(
quant_dtype=args.quant_dtype,
per_act_token_quant=args.per_token_quantized_activations,
per_out_ch_quant=args.per_channel_quantized_weights,
block_shape=args.block_shape)
return Config(
Ms=args.m,
K=args.k,
N=args.n,
E=args.num_experts,
topks=args.topk,
dtype=torch.bfloat16, # hard-code
quant_config=quant_config,
prepare_finalize_type=args.pf_type,
fused_experts_type=args.experts_type,
fused_moe_chunk_size=args.fused_moe_chunk_size,
world_size=args.world_size,
torch_trace_dir_path=args.torch_trace_dir_path)
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from enum import Enum
from itertools import product
from typing import Optional
import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
run_modular_kernel)
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
class Result(Enum):
PASS = 1
FAIL = 2
SKIP = 3
def rank_worker(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
cpu_group,
config: Config,
weights: WeightTensors,
):
current_platform.seed_everything(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# get weights to this device
weights.to_current_device()
Ms = config.Ms
assert isinstance(Ms, list)
TOPKs = config.topks
assert isinstance(TOPKs, list)
for m, topk in product(Ms, TOPKs):
print(f"Running m={m}, topk={topk} ...")
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
def make_feature_matrix(csv_file_path: str):
from dataclasses import asdict
import pandas as pd
def add_to_results(config: Config,
success: Result,
results_df: Optional[pd.DataFrame] = None):
config_dict = asdict(config)
config_dict['prepare_finalize_type'] = config_dict[
'prepare_finalize_type'].__name__
config_dict['fused_experts_type'] = config_dict[
'fused_experts_type'].__name__
config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
quant_config_dict = config_dict['quant_config']
del config_dict['quant_config']
if quant_config_dict is None:
quant_config = FusedMoEQuantConfig(None)
quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict
result_dict = config_dict | {'success': success.name}
result_df = pd.DataFrame([result_dict])
if results_df is None:
results_df = result_df
else:
results_df = pd.concat([results_df, result_df], ignore_index=True)
return results_df
Ms = [64]
Ks = [7168] # hidden sizes
Ns = [2048]
TOPKs = [[4, 1]]
Es = [32]
DTYPEs = [torch.bfloat16]
PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
FE_TYPES = MK_FUSED_EXPERT_TYPES
Q_TYPES = MK_QUANT_CONFIGS
combinations = list(
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))
results_df: Optional[pd.DataFrame] = None
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
combinations): #noqa: E501
config = Config(Ms=[m],
K=k,
N=n,
E=e,
topks=topks,
dtype=dtype,
prepare_finalize_type=pf_type,
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None)
success = None
if config.is_valid():
print(f"Running config : {config.describe()} ...")
try:
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker,
vllm_config, env_dict, config,
weights)
success = Result.PASS
except Exception as _:
success = Result.FAIL
else:
success = Result.SKIP
results_df = add_to_results(config, success, results_df)
if results_df is not None:
results_df.to_csv(f"{csv_file_path}")
if __name__ == '__main__':
import argparse
from pathlib import Path
parser = argparse.ArgumentParser(description=(
"Make ModularKernel feature matrix \n"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501
"-f ./feature_matrices/feature_matrix.csv"))
parser.add_argument("-f",
"--feature-matrix-csv-file-path",
type=str,
required=True,
help="File name to Generate a .csv file")
args = parser.parse_args()
csv_path = args.feature_matrix_csv_file_path
assert csv_path.endswith(
'csv'), f"Need a file path ending with .csv, got {csv_path}"
assert Path(csv_path).parent.is_dir(
), f"Cannot find parent directory for {Path(csv_path).parent}"
make_feature_matrix(args.feature_matrix_csv_file_path)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
if has_pplx():
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
if has_deep_ep():
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
MK_FUSED_EXPERT_TYPES = [
BatchedDeepGemmExperts,
BatchedTritonExperts,
NaiveBatchedExperts,
BatchedTritonOrDeepGemmExperts,
CutlassExpertsFp8,
DeepGemmExperts,
TritonOrDeepGemmExperts,
TritonExperts,
]
MK_QUANT_CONFIGS = [
None,
# per-channel / per-column weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
# per-channel / per-column weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
# per-tensor weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
# per-tensor weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
# block-quantized weights and 128 block per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
# TODO (varun) : Should we test the following combinations ?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Any, Callable, Optional
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.utils import get_open_port
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
local_rank: int):
import tempfile
temp_file = tempfile.mkstemp()[1]
set_current_vllm_config(vllm_config)
with set_current_vllm_config(vllm_config):
init_distributed_environment(
world_size=world_size,
rank=rank,
distributed_init_method=f"file://{temp_file}",
local_rank=local_rank,
backend="nccl",
)
initialize_model_parallel(
tensor_model_parallel_size=vllm_config.parallel_config.
tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.
pipeline_parallel_size,
)
cpu_group = torch.distributed.new_group(list(range(world_size)),
backend="gloo")
return cpu_group
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
P], None],
vllm_config: Optional[VllmConfig],
env_dict: Optional[dict],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
if env_dict is not None:
os.environ.update(env_dict)
cpu_group = None
if vllm_config is not None:
cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
vllm_config,
cpu_group,
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch_with_config(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
vllm_config: VllmConfig,
env_dict: dict[Any, Any],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
vllm_config,
env_dict,
) + args,
nprocs=world_size,
join=True,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from itertools import product
from typing import Any, Callable
import torch
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
def do_profile(fn: Callable,
fn_kwargs: dict[Any, Any],
pgi: ProcessGroupInfo,
config: Config,
num_warmups: int = 5):
for _ in range(num_warmups):
fn(**fn_kwargs)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=True,
) as tprof:
fn(**fn_kwargs)
torch.cuda.synchronize(torch.cuda.current_device())
# TODO (varun): Add a descriptive trace file name
tprof.export_chrome_trace(
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")
def profile_modular_kernel(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
config: Config,
weights: WeightTensors,
rank_tensors: RankTensors,
) -> None:
assert isinstance(config.Ms, int)
assert isinstance(config.topks, int)
# weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
# make modular kernel
mk = make_modular_kernel(config, vllm_config)
mk_kwargs = {
"hidden_states": rank_tensors.hidden_states,
"w1": rank_weights.w1,
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": rank_tensors.topk_ids,
"expert_map": rank_tensors.expert_map,
"w1_scale": rank_weights.w1_scale,
"w2_scale": rank_weights.w2_scale,
"a1_scale": rank_tensors.hidden_states_scale,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1,
}
do_profile(mk.forward, mk_kwargs, pgi, config)
def rank_worker(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
cpu_group,
config: Config,
weights: WeightTensors,
):
current_platform.seed_everything(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# get weights to this device
weights.to_current_device()
Ms = config.Ms
assert isinstance(Ms, list)
TOPKs = config.topks
assert isinstance(TOPKs, list)
for m, topk in product(Ms, TOPKs):
print(f"Running m={m}, topk={topk} ...")
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
def run(config: Config):
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights)
if __name__ == '__main__':
from .cli_args import make_config, make_config_arg_parser
parser = make_config_arg_parser(description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
args = parser.parse_args()
assert args.torch_trace_dir_path is not None, (
"Please pass in a directory to store torch traces")
config = make_config(args)
run(config)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
import vllm._custom_ops as ops
def per_token_cast_to_fp8(
x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(
x: torch.Tensor, block_size_k: int,
block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(
int(math.ceil(m / block_size_k)) * block_size_k,
int(math.ceil(n / block_size_n)) * block_size_n,
),
dtype=x.dtype,
device=x.device,
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_size_k,
x_padded.size(1) // block_size_k, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_non_quant_weights(
e: int,
n: int,
k: int,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2
"""
device = torch.cuda.current_device()
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
return w1, w2
def make_block_quant_fp8_weights(
e: int,
n: int,
k: int,
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
"""
dtype = torch.bfloat16
device = torch.cuda.current_device()
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device=device,
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device=device,
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
(k + (block_k - 1)) // block_k)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size_k=block_k,
block_size_n=block_n)
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size_k=block_k,
block_size_n=block_n)
return w1, w2, w1_s, w2_s
def make_quant_fp8_weights(
e: int,
n: int,
k: int,
per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return w1, w2, w1_scale, w2_scale
"""
q_dtype = torch.float8_e4m3fn
w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
# w1 -> w1_q, w2 -> w2_q
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
n_b_scales = 2 * n if per_out_channel_quant else 1
k_b_scales = k if per_out_channel_quant else 1
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
return w1_q, w2_q, w1_scale, w2_scale
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
DeepEP test utilities DeepEP test utilities
""" """
import dataclasses import dataclasses
import importlib
import os import os
import traceback import traceback
from typing import Callable, Optional from typing import Callable, Optional
...@@ -15,10 +14,9 @@ from torch.multiprocessing import ( ...@@ -15,10 +14,9 @@ from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage] spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port from vllm.utils import get_open_port, has_deep_ep
has_deep_ep = importlib.util.find_spec("deep_ep") is not None if has_deep_ep():
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize) DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from itertools import product
from typing import Optional
import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl,
run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config)
# TODO (varun): These requirements are very strict and could be relaxed.
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx())
meets_package_requirements = pytest.mark.skipif(
not has_all_packages,
reason="Requires deep_ep & deep_gemm & pplx packages",
)
def rank_worker(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
cpu_group,
config: Config,
weights: WeightTensors,
):
current_platform.seed_everything(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# get weights to this device
weights.to_current_device()
Ms = config.Ms
assert isinstance(Ms, list)
TOPKs = config.topks
assert isinstance(TOPKs, list)
for m, topk in product(Ms, TOPKs):
print(f"Running m={m}, topk={topk} ...")
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
def run(config: Config):
assert config.is_valid()
print(f"Testing config \n{config.describe()} ...")
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights)
Ms = [32, 64]
Ks = [7168] # hidden sizes
Ns = [2048]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZEs = [None, 16]
def is_nyi_config(config: Config) -> bool:
# We know these configs to be legitimate. but still fail.
if (config.fused_experts_type in [
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
TritonExperts, TritonOrDeepGemmExperts
]):
# The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1)
return unsupported_quant_config
# cutlass kernels dont support expert_maps yet.
return config.fused_experts_type == CutlassExpertsFp8
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination",
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@meets_package_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int):
config = Config(
Ms=Ms,
K=k,
N=n,
E=e,
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=fused_moe_chunk_size,
world_size=world_size,
)
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
print(f"{config.describe()}")
run(config)
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination",
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int):
config = Config(
Ms=Ms,
K=k,
N=n,
E=e,
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=fused_moe_chunk_size,
world_size=world_size,
)
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
run(config)
if __name__ == '__main__':
# Ability to test individual PrepareAndFinalize and FusedExperts combination
from .modular_kernel_tools.cli_args import (make_config,
make_config_arg_parser)
parser = make_config_arg_parser(description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
args = parser.parse_args()
config = make_config(args)
run(config)
...@@ -1072,6 +1072,7 @@ def torch_experts( ...@@ -1072,6 +1072,7 @@ def torch_experts(
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False, per_act_token_quant=False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
apply_router_weights_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
assert (global_num_experts == -1 assert (global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None) or (global_num_experts == w1.shape[0] and expert_map is None)
...@@ -1081,11 +1082,17 @@ def torch_experts( ...@@ -1081,11 +1082,17 @@ def torch_experts(
M, K = a.shape M, K = a.shape
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
if apply_router_weights_on_input:
assert topk == 1
a = a * topk_weight.to(a.dtype)
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype, if a1_scale:
assert not per_act_token_quant and block_shape is None
a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype,
per_act_token_quant, block_shape) per_act_token_quant, block_shape)
num_experts = w1.shape[0] num_experts = w1.shape[0]
...@@ -1104,6 +1111,7 @@ def torch_experts( ...@@ -1104,6 +1111,7 @@ def torch_experts(
tmp2 = SiluAndMul()(tmp1) tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1) out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None: elif block_shape is not None:
# block quantized
assert (a_scale is not None and w1_scale is not None assert (a_scale is not None and w1_scale is not None
and w2_scale is not None) and w2_scale is not None)
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
...@@ -1121,13 +1129,25 @@ def torch_experts( ...@@ -1121,13 +1129,25 @@ def torch_experts(
assert (a_scale is not None and w1_scale is not None assert (a_scale is not None and w1_scale is not None
and w2_scale is not None) and w2_scale is not None)
scales = a_scale if a_scale.numel() == 1 else a_scale[mask] scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
tmp1 = a[mask].to(f32) * scales tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = tmp1 @ w1_dq tmp1 = (tmp1 @ w1_dq).to(out.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, a2_scale, quant_dtype, per_act_token_quant,
block_shape)
assert b_scale is not None
tmp2 = tmp2.to(f32) * b_scale
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype) out[mask] = (tmp2 @ w2_dq).to(out.dtype)
if apply_router_weights_on_input:
return out
else:
return (out.view(M, -1, w2.shape[1]).to(f32) * return (out.view(M, -1, w2.shape[1]).to(f32) *
topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
......
...@@ -240,8 +240,7 @@ class DeviceCommunicatorBase: ...@@ -240,8 +240,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE" if module.__class__.__name__ == "FusedMoE"
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config, module.quant_method.init_prepare_finalize(module.moe_config)
module.quant_config)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
......
...@@ -37,7 +37,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -37,7 +37,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape, block_shape=block_shape,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
)) ))
self.allow_deep_gemm = allow_deep_gemm
self.batched_triton_experts = BatchedTritonExperts( self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
......
...@@ -81,13 +81,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -81,13 +81,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def init_prepare_finalize(self, moe: FusedMoEConfig, @staticmethod
quant_config: Optional[QuantizationConfig]): def maybe_make_prepare_finalize(
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
self.moe = moe
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
...@@ -160,8 +159,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -160,8 +159,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize( prepare_finalize = DeepEPLLPrepareAndFinalize(
handle, handle,
max_tokens_per_rank=moe.max_num_tokens, max_tokens_per_rank=moe.max_num_tokens,
...@@ -169,11 +166,18 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -169,11 +166,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
) )
return prepare_finalize
def init_prepare_finalize(self, moe: FusedMoEConfig):
self.moe = moe
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(
self.moe)
self.topk_indices_dtype = None self.topk_indices_dtype = None
if prepare_finalize is not None: if prepare_finalize is not None:
logger.debug("%s", prepare_finalize.__class__.__name__) logger.debug("%s", prepare_finalize.__class__.__name__)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, moe) experts = self.select_gemm_impl(prepare_finalize, self.moe)
self.fused_experts = FusedMoEModularKernel( self.fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
...@@ -44,8 +45,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -44,8 +45,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
) )
self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant
and use_fp8_w8a8) self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.deep_gemm_expert = DeepGemmExperts( self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None ) if self.allow_deep_gemm else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment