Unverified Commit 2ada4eca authored by Cunxiao Ni, committed by GitHub

[CI] Removes debug print statements from the example. (#1030)



* [CI] Removes debug print statements from the example.

* add parse args

* [Lint]: [pre-commit.ci] auto fixes [...]

* format

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e59e7f9a
@@ -6,6 +6,7 @@ from tvm import DataType
 import torch
 from dequantize_utils import torch_convert_bit_twiddling, assert_similar
 from tilelang.autotuner import set_autotune_inputs
+import argparse
 
 
 def get_configs():
@@ -433,13 +434,18 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
     expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)
     padding_M = sorted_token_ids.shape[0]  # padding_M: token number after padding
-    print(f'{sorted_token_ids=}')
-    print(f'{expert_ids=}')
     return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
 
 
-def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, topk=4, E=32):
+def main(m=256,
+         n=256,
+         k=256,
+         scale_size=32,
+         topk=4,
+         E=32,
+         fast_dequant=True,
+         with_bias=False,
+         tune=False):
     # Tunable parameters
     block_M, block_N, block_K = 128, 256, 128  # noqa: F841
     num_stages = 1  # noqa: F841
@@ -453,8 +459,25 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
     A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(
         m, n, k, qk, scale_size, topk, E, block_M)
 
-    with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
-        # Autotune with inputs manually composed
+    if tune:
+        with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
+            # Autotune with inputs manually composed
+            kernel = matmul(
+                m,
+                n,
+                k,
+                topk,
+                E,
+                padding_M,
+                "bfloat16",
+                "bfloat16",
+                "float32",
+                num_bits=num_bits,
+                scale_size=scale_size,
+                fast_dequant=fast_dequant,
+                with_bias=with_bias,
+            )
+    else:
         kernel = matmul(
             m,
             n,
@@ -469,8 +492,13 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
             scale_size=scale_size,
             fast_dequant=fast_dequant,
             with_bias=with_bias,
+            block_M=block_M,
+            block_N=block_N,
+            block_K=block_K,
+            num_stages=num_stages,
+            threads=threads,
+            split=split,
         )
-    print(f'Best config: {kernel.config}')
 
     output = kernel(
         A,
@@ -504,8 +532,25 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
 if __name__ == "__main__":
-    M, N, K = 16384, 5760, 2944  # From gpt-oss-20b MoE's first gemm
-    scale_size = 32
-    topk = 4  # experts activated for each token
-    E = 32  # number of experts
-    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--M", type=int, default=16384, help="M")  # From gpt-oss-20b MoE's first gemm
+    parser.add_argument("--N", type=int, default=5760, help="N")
+    parser.add_argument("--K", type=int, default=2944, help="K")
+    parser.add_argument("--scale_size", type=int, default=32, help="scale size")
+    parser.add_argument(
+        "--topk", type=int, default=4, help="topk")  # experts activated for each token
+    parser.add_argument("--E", type=int, default=32, help="E")  # number of experts
+    parser.add_argument("--tune", action="store_true", help="tune configs")
+    args = parser.parse_args()
+    main(
+        args.M,
+        args.N,
+        args.K,
+        args.scale_size,
+        topk=args.topk,
+        E=args.E,
+        fast_dequant=True,
+        with_bias=True,
+        tune=args.tune)
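
With this change the example is driven from the command line instead of hard-coded constants; a minimal usage sketch (the script path is a placeholder, since the diff does not show the file name):

    python example_dequant_gemm_moe.py --M 16384 --N 5760 --K 2944 --scale_size 32 --topk 4 --E 32 --tune

Without --tune, the kernel is compiled directly with the fixed block_M/block_N/block_K, num_stages, threads, and split values; with --tune, kernel construction runs under set_autotune_inputs(...) so the autotuner searches configurations against the actual MoE inputs.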