Unverified Commit 07610353 authored by Xiaoyu Zhang, committed by GitHub

fix log_info_on_rank0 error when running the benchmark (#6260)

parent c087ddd6
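
The underlying failure appears to be that the vLLM-vs-SGLang benchmark script calls into SGLang's fused MoE path, which logs through `log_info_on_rank0`; that helper presumably queries the tensor-parallel rank from `sglang.srt.distributed.parallel_state`, and the lookup fails when the script runs standalone with no distributed or model-parallel state initialized. The patch therefore bootstraps a single-rank process group before the benchmark and tears it down afterwards. A condensed sketch of that setup/teardown pattern, packaged as a context manager for clarity (the helper name `single_rank_parallel_state` and the reuse of `tcp://127.0.0.1:23456` are illustrative assumptions, not part of the patch):

```python
# Sketch of the single-rank bootstrap/teardown pattern the patch adds around the benchmark.
from contextlib import contextmanager

import torch
from sglang.srt.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    init_distributed_environment,
    initialize_model_parallel,
)


@contextmanager
def single_rank_parallel_state():
    """Initialize a one-rank process group plus model-parallel groups, then always clean up."""
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if not torch.distributed.is_initialized():
        # Same address and world size as in the patched main() below.
        torch.distributed.init_process_group(
            backend=backend,
            init_method="tcp://127.0.0.1:23456",
            world_size=1,
            rank=0,
        )
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method="tcp://127.0.0.1:23456",
        local_rank=0,
        backend=backend,
    )
    initialize_model_parallel(
        tensor_model_parallel_size=1, pipeline_model_parallel_size=1
    )
    try:
        yield
    finally:
        # Mirrors the try/finally in the patch: no leaked process groups on error.
        destroy_model_parallel()
        destroy_distributed_environment()
```

Running the teardown in `finally`, as the patch does, keeps stray process groups from leaking when the benchmark raises.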
benchmark/kernels/fused_moe_triton/README.md

@@ -58,15 +58,22 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
 # Compare with FP8 mode for Qwen2-57B
 python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
     --model Qwen/Qwen2-57B-A14B-Instruct \
-    --use-fp8
+    --use-fp8-w8a8

 # Compare with custom TP size
 python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
-    --tp-size 4
+    --model deepseek-ai/DeepSeek-V3-0324 \
+    --tp-size 8
+
+# Compare with custom TP size and n_share_experts_fusion
+python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
+    --model deepseek-ai/DeepSeek-V3-0324 \
+    --tp-size 8 \
+    --n-share-experts-fusion 8
 ```
 The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).

 - `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
-  Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
+  Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`; note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused_moe kernels.
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py

+# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
 import argparse

 import torch

@@ -31,11 +32,12 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] == "Qwen3MoeForCausalLM":
-        E = config.num_experts
+        E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size

@@ -99,7 +101,7 @@ def fused_moe_torch(
     a1_scale=None,
     a2_scale=None,
 ) -> torch.Tensor:
-    assert not use_fp8_w8a8, "Not supported"
+    assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
     topk_weights, topk_ids = fused_topk_native(
         hidden_states=x,

@@ -193,7 +195,7 @@ def fused_moe_sglang_api(
         args={},
     )
 )
-def benchmark(batch_size, provider, model_config, use_fp8=False):
+def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)

@@ -208,7 +210,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)

-    if use_fp8:
+    if use_fp8_w8a8:
         init_dtype = dtype
         w1 = torch.randn(
             num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype

@@ -244,7 +246,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2,
             input_gating,
             topk,
-            use_fp8_w8a8=use_fp8,
+            use_fp8_w8a8=use_fp8_w8a8,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,

@@ -260,7 +262,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2,
             input_gating,
             topk,
-            use_fp8_w8a8=use_fp8,
+            use_fp8_w8a8=use_fp8_w8a8,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,

@@ -277,7 +279,7 @@ def main():
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
     parser.add_argument("--tp-size", type=int, default=2)
-    parser.add_argument("--use-fp8", action="store_true")
+    parser.add_argument("--use-fp8-w8a8", action="store_true")
     parser.add_argument(
         "--save-path",
         type=str,

@@ -291,7 +293,7 @@ def main():
         print_data=True,
         save_path=args.save_path,
         model_config=model_config,
-        use_fp8=args.use_fp8,
+        use_fp8_w8a8=args.use_fp8_w8a8,
     )
...
benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py

+# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
 import argparse

 import torch

@@ -6,12 +7,18 @@ import vllm
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

+from sglang.srt.distributed.parallel_state import (
+    destroy_distributed_environment,
+    destroy_model_parallel,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe as fused_moe_sglang,
 )


-def get_model_config(model_name: str, tp_size: int):
+def get_model_config(model_name: str, tp_size: int, n_share_experts_fusion: int = 0):
     """Get model configuration parameters"""
     config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

@@ -36,7 +43,12 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = config.n_routed_experts
+        n_share_fusion_experts = n_share_experts_fusion
+        E = (
+            config.n_routed_experts + n_share_fusion_experts
+            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
+            else config.n_routed_experts
+        )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
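
On the expert-count change above: with shared-expert fusion, a DeepSeek-V3-style model gets `n_share_experts_fusion` extra expert slots on top of its routed experts, while DeepSeek-V2 keeps `n_routed_experts` unchanged. A small self-contained illustration of that arithmetic (the helper `effective_num_experts` is hypothetical; 256 and 160 are the published routed-expert counts for DeepSeek-V3 and DeepSeek-V2):

```python
# Hypothetical helper mirroring the E computation added in this hunk.
def effective_num_experts(
    architecture: str, n_routed_experts: int, n_share_experts_fusion: int = 0
) -> int:
    """Shared-expert fusion only applies to DeepSeek-V3-style models."""
    if architecture == "DeepseekV3ForCausalLM":
        return n_routed_experts + n_share_experts_fusion
    return n_routed_experts


# DeepSeek-V3: 256 routed experts, fused with 8 shared-expert copies -> 264 total.
assert effective_num_experts("DeepseekV3ForCausalLM", 256, 8) == 264
# DeepSeek-V2 ignores the fusion argument.
assert effective_num_experts("DeepseekV2ForCausalLM", 160, 8) == 160
```

So `--n-share-experts-fusion 8` on DeepSeek-V3-0324 benchmarks a 264-expert fused kernel rather than the plain 256-expert layout.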
@@ -182,7 +194,7 @@ def fused_moe_sglang_api(
         args={},
     )
 )
-def benchmark(batch_size, provider, model_config, use_fp8=False):
+def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)

@@ -193,12 +205,12 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     shard_intermediate_size = model_config["shard_intermediate_size"]
     topk = model_config["topk"]
     dtype = model_config["dtype"]
-    block_shape = getattr(model_config, "block_shape", None)
+    block_shape = model_config["block_shape"]

     x = torch.randn(num_tokens, hidden_size, dtype=dtype)

     w1_scale = w2_scale = a1_scale = a2_scale = None
-    if use_fp8:
+    if use_fp8_w8a8:
         init_dtype = dtype
         w1 = torch.randn(
             num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
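
The `block_shape` line above fixes a silent lookup bug rather than just renaming things: `model_config` is a plain `dict`, and `getattr` never sees a dict's keys, so the old expression always fell back to the `None` default even for block-quantized (e.g. FP8 blockwise) models. Plain indexing reads the value that `get_model_config` actually stores, assuming it always populates the key. A minimal demonstration (the `[128, 128]` block shape is an illustrative value):

```python
# model_config is a plain dict, so attribute lookup never finds its keys.
model_config = {"block_shape": [128, 128]}  # illustrative FP8 blockwise setting

# Old behaviour: getattr() returns the default because dicts expose no such attribute.
assert getattr(model_config, "block_shape", None) is None

# New behaviour: explicit key access returns the configured block shape.
assert model_config["block_shape"] == [128, 128]
```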
@@ -247,7 +259,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2,
             input_gating,
             topk,
-            use_fp8_w8a8=use_fp8,
+            use_fp8_w8a8=use_fp8_w8a8,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,

@@ -264,7 +276,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2,
             input_gating,
             topk,
-            use_fp8_w8a8=use_fp8,
+            use_fp8_w8a8=use_fp8_w8a8,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,

@@ -282,7 +294,8 @@ def main():
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
     parser.add_argument("--tp-size", type=int, default=2)
-    parser.add_argument("--use-fp8", action="store_true")
+    parser.add_argument("--n-share-experts-fusion", type=int, default=0)
+    parser.add_argument("--use-fp8-w8a8", action="store_true")
     parser.add_argument(
         "--save-path",
         type=str,

@@ -290,14 +303,41 @@ def main():
     )
     args = parser.parse_args()

-    model_config = get_model_config(args.model, args.tp_size)
-    benchmark.run(
-        show_plots=True,
-        print_data=True,
-        save_path=args.save_path,
-        model_config=model_config,
-        use_fp8=args.use_fp8,
-    )
+    try:
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                backend="nccl" if torch.cuda.is_available() else "gloo",
+                init_method="tcp://127.0.0.1:23456",
+                world_size=1,
+                rank=0,
+            )
+        init_distributed_environment(
+            world_size=1,
+            rank=0,
+            distributed_init_method="tcp://127.0.0.1:23456",
+            local_rank=0,
+            backend="nccl" if torch.cuda.is_available() else "gloo",
+        )
+        initialize_model_parallel(
+            tensor_model_parallel_size=1,
+            pipeline_model_parallel_size=1,
+        )
+        model_config = get_model_config(
+            args.model, args.tp_size, args.n_share_experts_fusion
+        )
+        benchmark.run(
+            show_plots=True,
+            print_data=True,
+            save_path=args.save_path,
+            model_config=model_config,
+            use_fp8_w8a8=args.use_fp8_w8a8,
+        )
+    finally:
+        destroy_model_parallel()
+        destroy_distributed_environment()


 if __name__ == "__main__":
...