Unverified commit 07610353 authored by Xiaoyu Zhang, committed by GitHub

fix log_info_on_rank0 error when run benchmark (#6260)

parent c087ddd6
@@ -58,15 +58,22 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
---use-fp8
+--use-fp8-w8a8
# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
---tp-size 4
+--model deepseek-ai/DeepSeek-V3-0324 \
+--tp-size 8
+# Compare with custom TP size and n_share_experts_fusion
+python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
+--model deepseek-ai/DeepSeek-V3-0324 \
+--tp-size 8 \
+--n-share-experts-fusion 8
```
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
-Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
+Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`; note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused_moe_kernel variants.
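For illustration only (this command is not part of the patch): a minimal invocation of the `torch.compile` comparison would follow the same flag conventions as the commands above, leaving out `--use-fp8-w8a8` since that mode is asserted unsupported in the script:

```bash
# Hypothetical usage sketch; flags mirror the triton comparison script shown above.
python benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py \
    --model Qwen/Qwen2-57B-A14B-Instruct \
    --tp-size 2
```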
+# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
@@ -31,11 +32,12 @@ def get_model_config(model_name: str, tp_size: int):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
-E = config.num_experts
+E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
@@ -99,7 +101,7 @@ def fused_moe_torch(
a1_scale=None,
a2_scale=None,
) -> torch.Tensor:
assert not use_fp8_w8a8, "Not supported"
assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
topk_weights, topk_ids = fused_topk_native(
hidden_states=x,
@@ -193,7 +195,7 @@ def fused_moe_sglang_api(
args={},
)
)
-def benchmark(batch_size, provider, model_config, use_fp8=False):
+def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -208,7 +210,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-if use_fp8:
+if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
@@ -244,7 +246,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
w2,
input_gating,
topk,
-use_fp8_w8a8=use_fp8,
+use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -260,7 +262,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
w2,
input_gating,
topk,
-use_fp8_w8a8=use_fp8,
+use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -277,7 +279,7 @@ def main():
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8", action="store_true")
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
@@ -291,7 +293,7 @@ def main():
print_data=True,
save_path=args.save_path,
model_config=model_config,
-use_fp8=args.use_fp8,
+use_fp8_w8a8=args.use_fp8_w8a8,
)
+# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
@@ -6,12 +7,18 @@ import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
+from sglang.srt.distributed.parallel_state import (
+destroy_distributed_environment,
+destroy_model_parallel,
+init_distributed_environment,
+initialize_model_parallel,
+)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
-def get_model_config(model_name: str, tp_size: int):
+def get_model_config(model_name: str, tp_size: int, n_share_experts_fusion: int = 0):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -36,7 +43,12 @@ def get_model_config(model_name: str, tp_size: int):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-E = config.n_routed_experts
+n_share_fusion_experts = n_share_experts_fusion
+E = (
+config.n_routed_experts + n_share_fusion_experts
+if config.architectures[0] in ["DeepseekV3ForCausalLM"]
+else config.n_routed_experts
+)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
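To make the fused-expert arithmetic in this hunk concrete, here is an illustrative standalone sketch; the DeepSeek-V3 numbers are hard-coded assumptions for illustration rather than values read from `AutoConfig`:

```python
# Illustrative only: mirrors the E / shard_intermediate_size computation above
# for DeepseekV3ForCausalLM with shared-expert fusion enabled.
n_routed_experts = 256        # assumed DeepSeek-V3 config value
moe_intermediate_size = 2048  # assumed DeepSeek-V3 moe_intermediate_size
tp_size = 8
n_share_experts_fusion = 8    # matches --n-share-experts-fusion 8 in the README example

E = n_routed_experts + n_share_experts_fusion                    # 256 + 8 = 264 experts
shard_intermediate_size = 2 * moe_intermediate_size // tp_size   # 4096 // 8 = 512
print(E, shard_intermediate_size)
```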
@@ -182,7 +194,7 @@ def fused_moe_sglang_api(
args={},
)
)
-def benchmark(batch_size, provider, model_config, use_fp8=False):
+def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -193,12 +205,12 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
-block_shape = getattr(model_config, "block_shape", None)
+block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1_scale = w2_scale = a1_scale = a2_scale = None
-if use_fp8:
+if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
@@ -247,7 +259,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
w2,
input_gating,
topk,
-use_fp8_w8a8=use_fp8,
+use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -264,7 +276,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
w2,
input_gating,
topk,
-use_fp8_w8a8=use_fp8,
+use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -282,7 +294,8 @@ def main():
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8", action="store_true")
parser.add_argument("--n-share-experts-fusion", type=int, default=0)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
@@ -290,14 +303,41 @@ def main():
)
args = parser.parse_args()
-model_config = get_model_config(args.model, args.tp_size)
-benchmark.run(
-show_plots=True,
-print_data=True,
-save_path=args.save_path,
-model_config=model_config,
-use_fp8=args.use_fp8,
-)
+try:
+if not torch.distributed.is_initialized():
+torch.distributed.init_process_group(
+backend="nccl" if torch.cuda.is_available() else "gloo",
+init_method="tcp://127.0.0.1:23456",
+world_size=1,
+rank=0,
+)
+init_distributed_environment(
+world_size=1,
+rank=0,
+distributed_init_method="tcp://127.0.0.1:23456",
+local_rank=0,
+backend="nccl" if torch.cuda.is_available() else "gloo",
+)
+initialize_model_parallel(
+tensor_model_parallel_size=1,
+pipeline_model_parallel_size=1,
+)
+model_config = get_model_config(
+args.model, args.tp_size, args.n_share_experts_fusion
+)
+benchmark.run(
+show_plots=True,
+print_data=True,
+save_path=args.save_path,
+model_config=model_config,
+use_fp8_w8a8=args.use_fp8_w8a8,
+)
+finally:
+destroy_model_parallel()
+destroy_distributed_environment()
if __name__ == "__main__":