"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "65859754f1463ce280bbaaf68d04797705849240"
Unverified Commit 1357397a authored by Liana Koleva, committed by GitHub
Browse files

feat: preview filename from tuning_fused_moe_triton.py (#12276)


Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
parent 42e1a72e
...@@ -376,6 +376,15 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: ...@@ -376,6 +376,15 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
def save_configs( def save_configs(
configs: Dict[int, BenchmarkConfig], configs: Dict[int, BenchmarkConfig],
filename: str,
) -> None:
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_filename(
num_experts: int, num_experts: int,
shard_intermediate_size: int, shard_intermediate_size: int,
hidden_size: int, hidden_size: int,
...@@ -404,10 +413,7 @@ def save_configs( ...@@ -404,10 +413,7 @@ def save_configs(
per_channel_quant, per_channel_quant,
) )
print(f"Writing best config to {filename}...") return filename
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
...@@ -541,7 +547,22 @@ def main(args: argparse.Namespace): ...@@ -541,7 +547,22 @@ def main(args: argparse.Namespace):
for config in search_space for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0 if block_k % config["BLOCK_SIZE_K"] == 0
] ]
print(f"Start tuning over {len(search_space)} configurations...")
filename = get_filename(
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
)
print(
f"Start tuning over {len(search_space)} configurations to create {filename}..."
)
start = time.perf_counter() start = time.perf_counter()
configs = _distribute( configs = _distribute(
...@@ -569,16 +590,7 @@ def main(args: argparse.Namespace): ...@@ -569,16 +590,7 @@ def main(args: argparse.Namespace):
} }
save_configs( save_configs(
best_configs, best_configs,
E, filename,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
) )
end = time.perf_counter() end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds") print(f"Tuning took {end - start:.2f} seconds")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment