Unverified commit 77830a26, authored by lukec and committed by GitHub
Browse files

Add fuse_moe per-channel tune (#10915)

parent fce17048
......@@ -47,6 +47,7 @@ def benchmark_config(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int] = None,
num_iters: int = 100,
) -> float:
......@@ -152,6 +153,7 @@ def benchmark_config(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
)
......@@ -261,6 +263,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(0)
......@@ -272,7 +275,12 @@ class BenchmarkWorker:
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_n,
block_k,
per_channel_quant,
)
if op_config is None:
config = get_default_config(
......@@ -299,6 +307,7 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
)
return config, kernel_time
......@@ -314,6 +323,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
......@@ -333,6 +343,7 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
num_iters=10,
)
......@@ -373,6 +384,7 @@ def save_configs(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
......@@ -389,6 +401,7 @@ def save_configs(
shard_intermediate_size // 2,
dtype_str,
block_shape,
per_channel_quant,
)
print(f"Writing best config to {filename}...")
......@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
per_channel_quant = args.per_channel_quant
block_shape = None
if (
hasattr(config, "quantization_config")
......@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
search_space,
)
......@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
)
end = time.perf_counter()
......@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
)
for batch_size in batch_sizes
......@@ -603,6 +620,10 @@ if __name__ == "__main__":
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
default="auto",
)
parser.add_argument(
"--per-channel-quant",
action="store_true",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment