Commit 2bd2d69e (unverified), authored by NaOHCC, committed by GitHub
Browse files

[Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy (#724)

* [Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy

* [Typo] Correct architecture selection for CUDA and CDNA
parent 8e1b88f3
......@@ -53,7 +53,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
topk = 10
carve_template = MatmulTemplate(
......
......@@ -187,7 +187,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
topk = 10
carve_template = MatmulTemplate(
......
......@@ -96,7 +96,7 @@ def kernel(N,
def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
......
......@@ -49,7 +49,7 @@ def kernel(
def main():
my_func = kernel(128, 128, 32, 3, 128, True)
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(f"Analyzed FLOPs: {result.total_flops}")
......
......@@ -16,7 +16,7 @@ def ref_program(A, B):
def get_configs(M, N, K, with_roller=False, topk=20):
if with_roller:
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
M=M,
N=N,
......
......@@ -281,10 +281,9 @@ class TensorCorePolicy(DefaultPolicy):
factors = factorize(np.prod(space) // warps)
def _score(node, thread): # small is better
def _score(node, warp_tile): # small is better
score = 0
block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
shape = node.propagate_inputs_on_reduction(block_tile)
shape = node.propagate_inputs_on_reduction(warp_tile)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, _ in enumerate(input_buffers):
score += np.prod(shape[i]) / self.arch.bandwidth[1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment