"...composable_kernel.git" did not exist on "99ebfebad1b2eb22e710a97ab9ef5bd2b6bb443e"
Unverified Commit 950ed16c authored by Cunxiao Ni's avatar Cunxiao Ni Committed by GitHub
Browse files

[Fix] fix some issues with JIT decorators existing in the examples (#681)



* [Fix] fix some issues with JIT decorators existing in the examples

* format

* Uses PassConfigKey instand of str

---------
Co-authored-by: default avatarCunxiao <nicunxiao@bytedance.com>
parent 689ee52b
...@@ -25,6 +25,7 @@ def ref_program(stride, padding, dilation): ...@@ -25,6 +25,7 @@ def ref_program(stride, padding, dilation):
return main return main
@tilelang.jit(out_idx=[2])
def convolution(N, def convolution(N,
C, C,
H, H,
...@@ -116,8 +117,7 @@ def main(argv=None): ...@@ -116,8 +117,7 @@ def main(argv=None):
block_k = 32 block_k = 32
num_stages = 3 num_stages = 3
threads = 256 threads = 256
program = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads)
kernel = tilelang.compile(program, out_idx=[2])
out_c = kernel(a, b) out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b) ref_c = ref_program(S, P, D)(a, b)
......
...@@ -32,10 +32,7 @@ def ref_program(stride, padding, dilation): ...@@ -32,10 +32,7 @@ def ref_program(stride, padding, dilation):
def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15): def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
if with_roller: if with_roller:
if torch.version.hip is not None: arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
arch=CDNA("hip")
else:
arch = CUDA("cuda")
carve_template = ConvTemplate( carve_template = ConvTemplate(
N=N, N=N,
C=C, C=C,
...@@ -102,6 +99,7 @@ def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15): ...@@ -102,6 +99,7 @@ def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False): def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
@tilelang.jit(out_idx=[2])
def kernel( def kernel(
block_M=None, block_M=None,
block_N=None, block_N=None,
...@@ -212,6 +210,7 @@ def get_heuristic_config() -> dict: ...@@ -212,6 +210,7 @@ def get_heuristic_config() -> dict:
} }
@tilelang.jit(out_idx=[2])
def convolution(N, def convolution(N,
C, C,
H, H,
...@@ -302,7 +301,7 @@ def main(n: int = 128, ...@@ -302,7 +301,7 @@ def main(n: int = 128,
kernel = result.kernel kernel = result.kernel
else: else:
config = get_heuristic_config() config = get_heuristic_config()
kernel = tilelang.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2]) kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench() tilelang_latency = profiler.do_bench()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment