Unverified Commit 77830a26 authored by lukec, committed by GitHub

Add fuse_moe per-channel tune (#10915)

parent fce17048
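This commit threads a per_channel_quant flag through the fused MoE tuning script so that kernels using per-channel weight quantization (one scale per output channel) are tuned and recorded separately from per-tensor or block-quantized ones. As a rough, self-contained illustration of what the flag distinguishes, the sketch below computes int8 weight scales both ways; the shapes and the quantize_int8 helper are illustrative assumptions, not code from this repository.

import torch


def quantize_int8(w: torch.Tensor, per_channel: bool):
    """Symmetric int8 weight quantization, per tensor or per output channel.

    Illustrative sketch only; the real kernels consume such scales through the
    *_scale arguments visible in the diff below.
    """
    if per_channel:
        # One scale per output channel (row of the weight matrix).
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    else:
        # A single scale shared by the whole tensor.
        scale = w.abs().amax().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale


w = torch.randn(1024, 512)  # e.g. one expert's weight shard (shape assumed)
q_t, s_t = quantize_int8(w, per_channel=False)  # s_t is a scalar
q_c, s_c = quantize_int8(w, per_channel=True)   # s_c has shape (1024, 1)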
@@ -47,6 +47,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    per_channel_quant: bool,
     block_shape: List[int] = None,
     num_iters: int = 100,
 ) -> float:
@@ -152,6 +153,7 @@ def benchmark_config(
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
@@ -261,6 +263,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a8: bool,
         use_int8_w8a16: bool,
+        per_channel_quant: bool,
         block_shape: List[int],
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(0)
@@ -272,7 +275,12 @@ class BenchmarkWorker:
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0
         op_config = get_moe_configs(
-            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
+            num_experts,
+            shard_intermediate_size // 2,
+            dtype_str,
+            block_n,
+            block_k,
+            per_channel_quant,
         )
         if op_config is None:
             config = get_default_config(
@@ -299,6 +307,7 @@ class BenchmarkWorker:
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            per_channel_quant,
             block_shape,
         )
         return config, kernel_time
@@ -314,6 +323,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a8: bool,
         use_int8_w8a16: bool,
+        per_channel_quant: bool,
         block_shape: List[int],
         search_space: List[Dict[str, int]],
     ) -> Dict[str, int]:
@@ -333,6 +343,7 @@ class BenchmarkWorker:
                         use_fp8_w8a8,
                         use_int8_w8a8,
                         use_int8_w8a16,
+                        per_channel_quant,
                         block_shape,
                         num_iters=10,
                     )
@@ -373,6 +384,7 @@ def save_configs(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    per_channel_quant: bool,
     block_shape: List[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
@@ -389,6 +401,7 @@ def save_configs(
         shard_intermediate_size // 2,
         dtype_str,
         block_shape,
+        per_channel_quant,
     )
     print(f"Writing best config to {filename}...")
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    per_channel_quant = args.per_channel_quant
     block_shape = None
     if (
         hasattr(config, "quantization_config")
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
                     use_fp8_w8a8,
                     use_int8_w8a8,
                     use_int8_w8a16,
+                    per_channel_quant,
                     block_shape,
                     search_space,
                 )
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            per_channel_quant,
             block_shape,
         )
         end = time.perf_counter()
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
                     use_fp8_w8a8,
                     use_int8_w8a8,
                     use_int8_w8a16,
+                    per_channel_quant,
                     block_shape,
                 )
                 for batch_size in batch_sizes
@@ -603,6 +620,10 @@ if __name__ == "__main__":
         choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
         default="auto",
     )
+    parser.add_argument(
+        "--per-channel-quant",
+        action="store_true",
+    )
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
...
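The new command-line flag is a plain boolean switch: argparse's store_true turns --per-channel-quant into args.per_channel_quant, which main() forwards as per_channel_quant to the tuning and benchmarking paths above, so a tuning run simply adds --per-channel-quant next to --tune and the chosen --dtype. A minimal sketch of the flag's behaviour, reproducing only the arguments shown in this diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype",
    choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
    default="auto",
)
parser.add_argument("--per-channel-quant", action="store_true")
parser.add_argument("--tune", action="store_true")

print(parser.parse_args([]).per_channel_quant)                        # False
print(parser.parse_args(["--per-channel-quant"]).per_channel_quant)   # True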