"tests/vscode:/vscode.git/clone" did not exist on "4ae77dfd42041dc2defe21f6ccf76aecb4478812"
Unverified commit 6d3da472, authored by Jee Jee Li, committed by GitHub
Browse files

[Misc] Add --save-dir option to benchmark_moe (#23020)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 78863f8c
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import argparse import argparse
import json import json
import os
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
...@@ -542,6 +543,7 @@ def save_configs( ...@@ -542,6 +543,7 @@ def save_configs(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: list[int], block_quant_shape: list[int],
save_dir: str,
) -> None: ) -> None:
dtype_str = get_config_dtype_str( dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
...@@ -552,7 +554,8 @@ def save_configs( ...@@ -552,7 +554,8 @@ def save_configs(
filename = get_config_file_name( filename = get_config_file_name(
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
) )
os.makedirs(save_dir, exist_ok=True)
filename = os.path.join(save_dir, filename)
print(f"Writing best config to {filename}...") print(f"Writing best config to {filename}...")
with open(filename, "w") as f: with open(filename, "w") as f:
json.dump(configs, f, indent=4) json.dump(configs, f, indent=4)
...@@ -707,6 +710,7 @@ def main(args: argparse.Namespace): ...@@ -707,6 +710,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
block_quant_shape, block_quant_shape,
args.save_dir,
) )
end = time.time() end = time.time()
print(f"Tuning took {end - start:.2f} seconds") print(f"Tuning took {end - start:.2f} seconds")
...@@ -748,6 +752,9 @@ if __name__ == "__main__": ...@@ -748,6 +752,9 @@ if __name__ == "__main__":
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
) )
parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument(
"--save-dir", type=str, default="./", help="Directory to save tuned results"
)
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, nargs="+", required=False) parser.add_argument("--batch-size", type=int, nargs="+", required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment