"...composable_kernel_onnxruntime.git" did not exist on "0271338ed4d2d6d83f2cd032fffe6726eadfb99d"
Unverified commit 2bd2d69e authored by NaOHCC, committed by GitHub

[Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy (#724)

* [Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy

* [Typo] Correct architecture selection for CUDA and CDNA
parent 8e1b88f3
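
Note on the second fix: it corrects an inverted ternary. `torch.version.hip` is `None` on NVIDIA (CUDA) builds of PyTorch and a version string on AMD (ROCm/HIP) builds, so the `None` branch must select `CUDA`, not `CDNA`. A minimal sketch of the corrected selection, assuming the `CUDA`/`CDNA` classes come from `tilelang.carver.arch` as used in the diff below (the `detect_arch` helper is hypothetical):

    import torch
    from tilelang.carver.arch import CUDA, CDNA  # import path assumed

    def detect_arch():  # hypothetical helper; mirrors the fixed branch order
        # torch.version.hip is None on NVIDIA builds of PyTorch,
        # and a version string on ROCm builds.
        if torch.version.hip is None:
            return CUDA("cuda")  # NVIDIA GPU: CUDA target
        return CDNA("hip")       # AMD GPU: CDNA target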
@@ -53,7 +53,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch
-    arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     topk = 10
     carve_template = MatmulTemplate(
...
@@ -187,7 +187,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch
-    arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     topk = 10
     carve_template = MatmulTemplate(
...
@@ -96,7 +96,7 @@ def kernel(N,
 def main():
     my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-    cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(result)
     print(f"Analyzed FLOPs: {result.total_flops}")
...
@@ -49,7 +49,7 @@ def kernel(
 def main():
     my_func = kernel(128, 128, 32, 3, 128, True)
-    cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(f"Analyzed FLOPs: {result.total_flops}")
...
@@ -16,7 +16,7 @@ def ref_program(A, B):
 def get_configs(M, N, K, with_roller=False, topk=20):
     if with_roller:
-        arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
         carve_template = MatmulTemplate(
             M=M,
             N=N,
...
@@ -281,10 +281,9 @@ class TensorCorePolicy(DefaultPolicy):
         factors = factorize(np.prod(space) // warps)

-        def _score(node, thread):  # small is better
+        def _score(node, warp_tile):  # small is better
             score = 0
-            block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
-            shape = node.propagate_inputs_on_reduction(block_tile)
+            shape = node.propagate_inputs_on_reduction(warp_tile)
             input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
             for i, _ in enumerate(input_buffers):
                 score += np.prod(shape[i]) / self.arch.bandwidth[1]
...
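
For reference, the substance of the score-function fix: the candidate `warp_tile` is now propagated through the reduction directly, instead of first deriving a `block_tile` by dividing by a thread count, so the score reflects the memory traffic of the warp tile itself. A minimal standalone sketch of that scoring idea, with made-up shapes and a made-up bandwidth standing in for the real node/analyzer objects:

    import numpy as np

    def score_warp_tile(input_shapes, bandwidth):  # small is better
        # Approximate cost: elements read from each input buffer for one
        # candidate warp tile, weighted by the memory-bandwidth term.
        return sum(np.prod(shape) / bandwidth for shape in input_shapes)

    # Illustration only: two input buffers touched by one warp tile
    # (shapes and bandwidth are invented values, not from the policy).
    print(score_warp_tile([(16, 32), (32, 16)], bandwidth=1.6e12))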