Unverified Commit 8361eb5c authored by Zhang Jason's avatar Zhang Jason Committed by GitHub
Browse files

[Examples] Add the support of rocm arch detecting (#661)


Co-authored-by: default avatarzhangnju <ningzhan@SMC-SC-DI08-33.dh144.dcgpu>
parent d764dca8
......@@ -49,7 +49,13 @@ def get_configs(args, kwargs):
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10
......
......@@ -183,7 +183,13 @@ def get_configs(args, kwargs):
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10
......
......@@ -50,7 +50,13 @@ def get_configs(args, kwargs):
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10
......
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.layout import make_swizzled_layout
import torch
N = 64
C = 256
H = 512
......@@ -94,6 +95,9 @@ def kernel(N,
def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
if torch.version.hip is not None:
cuda_device=CDNA("hip")
else:
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
......
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
import torch
M = N = K = 1024
......@@ -47,6 +49,9 @@ def kernel(
def main():
my_func = kernel(128, 128, 32, 3, 128, True)
if torch.version.hip is not None:
cuda_device=CDNA("hip")
else:
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device)
......
......@@ -6,6 +6,7 @@ import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
......@@ -31,6 +32,9 @@ def ref_program(stride, padding, dilation):
def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
if with_roller:
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
carve_template = ConvTemplate(
N=N,
......
......@@ -6,6 +6,7 @@ import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
......@@ -15,6 +16,9 @@ def ref_program(A, B):
def get_configs(M, N, K, with_roller=False, topk=20):
if with_roller:
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
carve_template = MatmulTemplate(
M=M,
......
......@@ -4,7 +4,7 @@ from .cpu import CPU
from .cdna import CDNA
from typing import Union
from tvm.target import Target
import torch
def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
if isinstance(target, str):
......@@ -23,7 +23,12 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
if torch.version.hip is not None:
return get_arch("hip")
if torch.cuda.is_available():
return get_arch("cuda")
else:
return get_arch("llvm")
from .cpu import is_cpu_arch # noqa: F401
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment