Commit cd3ed273 authored by zhuwenwen
Browse files

update benchmark_moe.py

parent be0549c4
...@@ -7,19 +7,19 @@ import time ...@@ -7,19 +7,19 @@ import time
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from itertools import product from itertools import product
from typing import Any, TypedDict from typing import Any, TypedDict, Optional
import ray import ray
import torch import torch
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype() # 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict): class BenchmarkConfig(TypedDict):
...@@ -47,8 +47,12 @@ def benchmark_config( ...@@ -47,8 +47,12 @@ def benchmark_config(
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False nn_moe: Optional[bool] = False
) -> float: ) -> float:
from vllm.platforms import current_platform
device = torch.cuda.current_device()
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
if use_int8_w8a16: if use_int8_w8a16:
if not nn_moe: if not nn_moe:
w1 = torch.randint( w1 = torch.randint(
...@@ -60,6 +64,7 @@ def benchmark_config( ...@@ -60,6 +64,7 @@ def benchmark_config(
hidden_size, hidden_size,
), ),
dtype=torch.int8, dtype=torch.int8,
device=device,
) )
w2 = torch.randint( w2 = torch.randint(
-127, -127,
...@@ -70,6 +75,7 @@ def benchmark_config( ...@@ -70,6 +75,7 @@ def benchmark_config(
shard_intermediate_size // 2, shard_intermediate_size // 2,
), ),
dtype=torch.int8, dtype=torch.int8,
device=device,
) )
else: else:
w1 = torch.randint( w1 = torch.randint(
...@@ -81,6 +87,7 @@ def benchmark_config( ...@@ -81,6 +87,7 @@ def benchmark_config(
shard_intermediate_size, shard_intermediate_size,
), ),
dtype=torch.int8, dtype=torch.int8,
device=device,
) )
w2 = torch.randint( w2 = torch.randint(
-127, -127,
...@@ -91,23 +98,24 @@ def benchmark_config( ...@@ -91,23 +98,24 @@ def benchmark_config(
hidden_size, hidden_size,
), ),
dtype=torch.int8, dtype=torch.int8,
device=device,
) )
else: else:
if not nn_moe: if not nn_moe:
w1 = torch.randn( w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device
) )
w2 = torch.randn( w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device
) )
else: else:
w1 = torch.randn( w1 = torch.randn(
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype, device=device
) )
w2 = torch.randn( w2 = torch.randn(
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype, device=device
) )
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device)
w1_scale = None w1_scale = None
w2_scale = None w2_scale = None
...@@ -115,9 +123,9 @@ def benchmark_config( ...@@ -115,9 +123,9 @@ def benchmark_config(
a2_scale = None a2_scale = None
if use_int8_w8a16: if use_int8_w8a16:
w1_scale = torch.randn( w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32 (num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device
) )
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device)
if use_fp8_w8a8: if use_fp8_w8a8:
if block_quant_shape: if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1] block_n, block_k = block_quant_shape[0], block_quant_shape[1]
...@@ -130,24 +138,26 @@ def benchmark_config( ...@@ -130,24 +138,26 @@ def benchmark_config(
k_tiles_w1 = (K + block_k - 1) // block_k k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k k_tiles_w2 = (N + block_k - 1) // block_k
w1_scale = ( w1_scale = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device)
* factor_for_scale * factor_for_scale
) )
w2_scale = ( w2_scale = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device)
* factor_for_scale * factor_for_scale
) )
else: else:
w1_scale = torch.randn(num_experts, dtype=torch.float32) w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
w2_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
a1_scale = torch.randn(1, dtype=torch.float32) a1_scale = torch.randn(1, dtype=torch.float32, device=device)
a2_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32, device=device)
# 获取 FP8_DTYPE
FP8_DTYPE = current_platform.fp8_dtype()
w1 = w1.to(FP8_DTYPE) w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE) w2 = w2.to(FP8_DTYPE)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device)
def prepare(i: int): def prepare(i: int):
input_gating.copy_(gating_output[i]) input_gating.copy_(gating_output[i])
...@@ -266,6 +276,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): ...@@ -266,6 +276,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
def get_configs_compute_bound(use_fp16, block_quant_shape, nn_moe: Optional[bool] = False) -> list[dict[str, int]]: def get_configs_compute_bound(use_fp16, block_quant_shape, nn_moe: Optional[bool] = False) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = [] configs: list[BenchmarkConfig] = []
# 局部导入 current_platform
from vllm.platforms import current_platform
if current_platform.is_rocm(): if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16, nn_moe) param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
...@@ -426,12 +439,18 @@ def merge_unique_dicts(list1, list2): ...@@ -426,12 +439,18 @@ def merge_unique_dicts(list1, list2):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class BenchmarkWorker: class BenchmarkWorker:
def __init__(self, seed: int, device_id: int) -> None: def __init__(self, seed: int, device_id: int) -> None:
torch.set_default_device("cuda:"+ str(device_id)) from vllm.platforms import current_platform
import os
if current_platform.is_rocm():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else:
torch.set_default_device("cuda:"+ str(device_id))
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
self.seed = seed self.seed = seed
# Get the device ID to allocate tensors and kernels # Store the logical device ID for Ray
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = device_id self.device_id = device_id
def benchmark( def benchmark(
...@@ -448,7 +467,13 @@ class BenchmarkWorker: ...@@ -448,7 +467,13 @@ class BenchmarkWorker:
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False, nn_moe: Optional[bool] = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
# 局部导入 current_platform
from vllm.platforms import current_platform
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_moe_configs, get_default_config
)
dtype_str = get_config_dtype_str( dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
) )
...@@ -502,6 +527,9 @@ class BenchmarkWorker: ...@@ -502,6 +527,9 @@ class BenchmarkWorker:
use_deep_gemm: bool, use_deep_gemm: bool,
nn_moe: Optional[bool] = False, nn_moe: Optional[bool] = False,
) -> dict[str, int]: ) -> dict[str, int]:
from vllm.platforms import current_platform
import os
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -515,10 +543,16 @@ class BenchmarkWorker: ...@@ -515,10 +543,16 @@ class BenchmarkWorker:
topk, topk,
) )
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard = False need_device_guard = False
if current_platform.is_rocm(): if current_platform.is_rocm():
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None) # For ROCm with Ray, skip additional device context management
if visible_device != f"{self.device_id}": need_device_guard = False
else:
# For other platforms, use device guard if needed
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if visible_devices is not None and len(visible_devices.split(',')) > 1:
need_device_guard = True need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
...@@ -587,6 +621,10 @@ def save_configs( ...@@ -587,6 +621,10 @@ def save_configs(
block_quant_shape: list[int], block_quant_shape: list[int],
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> None: ) -> None:
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_config_file_name
)
dtype_str = get_config_dtype_str( dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
) )
...@@ -611,6 +649,12 @@ def get_weight_block_size_safety(config, default_value=None): ...@@ -611,6 +649,12 @@ def get_weight_block_size_safety(config, default_value=None):
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
import os
import logging
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
print(args) print(args)
tp_size = args.tp_size tp_size = args.tp_size
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment