"tests/entrypoints/pooling/basic/test_encode.py" did not exist on "2554b27baa58b15843367f92d7f73d71bb89033d"
Unverified commit 4589b940, authored by Tianyu Guo, committed by GitHub
Browse files

[Bugfix] Fix benchmark_moe.py (#19016)


Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
parent cc867be1
...@@ -7,7 +7,6 @@ import time ...@@ -7,7 +7,6 @@ import time
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from itertools import product from itertools import product
from types import SimpleNamespace
from typing import Any, TypedDict from typing import Any, TypedDict
import ray import ray
...@@ -43,7 +42,7 @@ def benchmark_config( ...@@ -43,7 +42,7 @@ def benchmark_config(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
block_quant_shape: List[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
...@@ -400,7 +399,7 @@ class BenchmarkWorker: ...@@ -400,7 +399,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: List[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
...@@ -532,7 +531,7 @@ def save_configs( ...@@ -532,7 +531,7 @@ def save_configs(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: List[int], block_quant_shape: list[int],
) -> None: ) -> None:
dtype_str = get_config_dtype_str( dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
...@@ -563,7 +562,6 @@ def main(args: argparse.Namespace): ...@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
if args.model_prefix: if args.model_prefix:
config = getattr(config, args.model_prefix) config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
...@@ -595,11 +593,7 @@ def main(args: argparse.Namespace): ...@@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = ( dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
torch.float16
if current_platform.is_rocm()
else getattr(torch, config.torch_dtype)
)
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment