Unverified commit 8361eb5c, authored by Zhang Jason, committed by GitHub
Browse files

[Examples] Add the support of rocm arch detecting (#661)


Co-authored-by: zhangnju <ningzhan@SMC-SC-DI08-33.dh144.dcgpu>
parent d764dca8
...@@ -49,8 +49,14 @@ def get_configs(args, kwargs): ...@@ -49,8 +49,14 @@ def get_configs(args, kwargs):
if with_roller: if with_roller:
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda") import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10 topk = 10
carve_template = MatmulTemplate( carve_template = MatmulTemplate(
......
...@@ -183,8 +183,14 @@ def get_configs(args, kwargs): ...@@ -183,8 +183,14 @@ def get_configs(args, kwargs):
if with_roller: if with_roller:
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda") import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10 topk = 10
carve_template = MatmulTemplate( carve_template = MatmulTemplate(
......
...@@ -50,8 +50,14 @@ def get_configs(args, kwargs): ...@@ -50,8 +50,14 @@ def get_configs(args, kwargs):
if with_roller: if with_roller:
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda") import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10 topk = 10
carve_template = MatmulTemplate( carve_template = MatmulTemplate(
......
import tilelang.language as T import tilelang.language as T
from tilelang.tools import Analyzer from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.layout import make_swizzled_layout from tilelang.layout import make_swizzled_layout
import torch
N = 64 N = 64
C = 256 C = 256
H = 512 H = 512
...@@ -94,7 +95,10 @@ def kernel(N, ...@@ -94,7 +95,10 @@ def kernel(N,
def main(): def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256) my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CUDA("cuda") if torch.version.hip is not None:
cuda_device=CDNA("hip")
else:
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device) result = Analyzer.analysis(my_func, cuda_device)
print(result) print(result)
print(f"Analyzed FLOPs: {result.total_flops}") print(f"Analyzed FLOPs: {result.total_flops}")
......
import tilelang.language as T import tilelang.language as T
from tilelang.tools import Analyzer from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
import torch
M = N = K = 1024 M = N = K = 1024
...@@ -47,7 +49,10 @@ def kernel( ...@@ -47,7 +49,10 @@ def kernel(
def main(): def main():
my_func = kernel(128, 128, 32, 3, 128, True) my_func = kernel(128, 128, 32, 3, 128, True)
cuda_device = CUDA("cuda") if torch.version.hip is not None:
cuda_device=CDNA("hip")
else:
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device) result = Analyzer.analysis(my_func, cuda_device)
print(f"Analyzed FLOPs: {result.total_flops}") print(f"Analyzed FLOPs: {result.total_flops}")
......
...@@ -6,6 +6,7 @@ import tilelang.language as T ...@@ -6,6 +6,7 @@ import tilelang.language as T
from tilelang.autotuner import AutoTuner from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate from tilelang.carver.template import ConvTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
...@@ -31,7 +32,10 @@ def ref_program(stride, padding, dilation): ...@@ -31,7 +32,10 @@ def ref_program(stride, padding, dilation):
def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15): def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
if with_roller: if with_roller:
arch = CUDA("cuda") if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
carve_template = ConvTemplate( carve_template = ConvTemplate(
N=N, N=N,
C=C, C=C,
......
...@@ -6,6 +6,7 @@ import tilelang.language as T ...@@ -6,6 +6,7 @@ import tilelang.language as T
from tilelang.autotuner import AutoTuner from tilelang.autotuner import AutoTuner
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
...@@ -15,7 +16,10 @@ def ref_program(A, B): ...@@ -15,7 +16,10 @@ def ref_program(A, B):
def get_configs(M, N, K, with_roller=False, topk=20): def get_configs(M, N, K, with_roller=False, topk=20):
if with_roller: if with_roller:
arch = CUDA("cuda") if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
carve_template = MatmulTemplate( carve_template = MatmulTemplate(
M=M, M=M,
N=N, N=N,
......
...@@ -4,7 +4,7 @@ from .cpu import CPU ...@@ -4,7 +4,7 @@ from .cpu import CPU
from .cdna import CDNA from .cdna import CDNA
from typing import Union from typing import Union
from tvm.target import Target from tvm.target import Target
import torch
def get_arch(target: Union[str, Target] = "cuda") -> TileDevice: def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
if isinstance(target, str): if isinstance(target, str):
...@@ -23,7 +23,12 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice: ...@@ -23,7 +23,12 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
def auto_infer_current_arch() -> TileDevice: def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture # TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future # Can be replaced by a more sophisticated method in the future
return get_arch("cuda") if torch.version.hip is not None:
return get_arch("hip")
if torch.cuda.is_available():
return get_arch("cuda")
else:
return get_arch("llvm")
from .cpu import is_cpu_arch # noqa: F401 from .cpu import is_cpu_arch # noqa: F401
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment