Commit cd3ed273 authored by zhuwenwen's avatar zhuwenwen
Browse files

update benchmark_moe.py

parent be0549c4
......@@ -7,19 +7,19 @@ import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict
from typing import Any, TypedDict, Optional
import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
# 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict):
......@@ -47,8 +47,12 @@ def benchmark_config(
use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False
) -> float:
from vllm.platforms import current_platform
device = torch.cuda.current_device()
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
if use_int8_w8a16:
if not nn_moe:
w1 = torch.randint(
......@@ -60,6 +64,7 @@ def benchmark_config(
hidden_size,
),
dtype=torch.int8,
device=device,
)
w2 = torch.randint(
-127,
......@@ -70,6 +75,7 @@ def benchmark_config(
shard_intermediate_size // 2,
),
dtype=torch.int8,
device=device,
)
else:
w1 = torch.randint(
......@@ -81,6 +87,7 @@ def benchmark_config(
shard_intermediate_size,
),
dtype=torch.int8,
device=device,
)
w2 = torch.randint(
-127,
......@@ -91,23 +98,24 @@ def benchmark_config(
hidden_size,
),
dtype=torch.int8,
device=device,
)
else:
if not nn_moe:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device
)
else:
w1 = torch.randn(
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype, device=device
)
w2 = torch.randn(
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype, device=device
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device)
w1_scale = None
w2_scale = None
......@@ -115,9 +123,9 @@ def benchmark_config(
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device)
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
......@@ -130,24 +138,26 @@ def benchmark_config(
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_scale = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device)
* factor_for_scale
)
w2_scale = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device)
* factor_for_scale
)
else:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32, device=device)
a2_scale = torch.randn(1, dtype=torch.float32, device=device)
# 获取 FP8_DTYPE
FP8_DTYPE = current_platform.fp8_dtype()
w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device)
def prepare(i: int):
input_gating.copy_(gating_output[i])
......@@ -266,6 +276,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
def get_configs_compute_bound(use_fp16, block_quant_shape, nn_moe: Optional[bool] = False) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = []
# 局部导入 current_platform
from vllm.platforms import current_platform
if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
......@@ -426,12 +439,18 @@ def merge_unique_dicts(list1, list2):
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int, device_id: int) -> None:
torch.set_default_device("cuda:"+ str(device_id))
from vllm.platforms import current_platform
import os
if current_platform.is_rocm():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else:
torch.set_default_device("cuda:"+ str(device_id))
current_platform.seed_everything(seed)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
# Store the logical device ID for Ray
self.device_id = device_id
def benchmark(
......@@ -448,7 +467,13 @@ class BenchmarkWorker:
use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False,
) -> tuple[dict[str, int], float]:
# 局部导入 current_platform
from vllm.platforms import current_platform
current_platform.seed_everything(self.seed)
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_moe_configs, get_default_config
)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
......@@ -502,6 +527,9 @@ class BenchmarkWorker:
use_deep_gemm: bool,
nn_moe: Optional[bool] = False,
) -> dict[str, int]:
from vllm.platforms import current_platform
import os
best_config = None
best_time = float("inf")
if current_platform.is_rocm():
......@@ -515,10 +543,16 @@ class BenchmarkWorker:
topk,
)
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard = False
if current_platform.is_rocm():
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
if visible_device != f"{self.device_id}":
# For ROCm with Ray, skip additional device context management
need_device_guard = False
else:
# For other platforms, use device guard if needed
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if visible_devices is not None and len(visible_devices.split(',')) > 1:
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
......@@ -587,6 +621,10 @@ def save_configs(
block_quant_shape: list[int],
use_nn_moe: Optional[bool] = False,
) -> None:
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_config_file_name
)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
......@@ -611,6 +649,12 @@ def get_weight_block_size_safety(config, default_value=None):
def main(args: argparse.Namespace):
import os
import logging
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
print(args)
tp_size = args.tp_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment