Commit 782ca9f6 authored by xs-keju, committed by GitHub

Add CPU JIT with ctypes backend (#154)



* Add CPU JIT with ctypes backend

* Resolve some lint issues

* Apply PR feedback on header file and kernel example

* Add test cases

* Resolve formatting issues

* Resolve formatting issues

---------
Co-authored-by: xxw <1990389406@qq.con>
parent 3486e27e
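For context, a minimal, hedged sketch of the new CPU JIT path this PR enables. It mirrors the test added below; the element-wise kernel, its names, and the tolerances are illustrative and not part of this change, while `is_cpu=True`, `execution_backend="ctypes"`, and `target="c"` come from the diff itself.

import tilelang
import tilelang.language as T
import torch


def make_add(N, block_N, dtype="float32"):
    # Hypothetical one-block-per-program element-wise add, written with the
    # same is_cpu=True launch used by the matmul test in this PR.
    @T.prim_func
    def add(
            A: T.Buffer((N,), dtype),
            B: T.Buffer((N,), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), is_cpu=True) as bx:
            for i in T.serial(block_N):
                C[bx * block_N + i] = A[bx * block_N + i] + B[bx * block_N + i]

    return add


# Compile for the C target and execute through the new ctypes backend.
func = tilelang.compile(make_add(1024, 128), -1, execution_backend="ctypes", target="c")
a = torch.randn(1024)
b = torch.randn(1024)
c = func(a, b)
torch.testing.assert_close(c, a + b)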
#pragma once
#include "half.hpp"
#include <math.h>
#include <stdbool.h>
using half_float::half;
// Not Implemented
\ No newline at end of file
#pragma once
// Not Implemented
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <math.h>
#include <stdbool.h>
// Not Implemented
\ No newline at end of file
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
// Not Implemented
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
import torch


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@@ -62,5 +60,58 @@ def test_matmul_codegen():
    assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32)
def test_matmul_compile():

    def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
        # A simple blocked matmul kernel used only to exercise the CPU JIT path.

        @T.prim_func
        def matmul(
                A: T.Buffer((M, K), dtype),
                B: T.Buffer((K, N), dtype),
                C: T.Buffer((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
                A_local = T.alloc_local((block_M, block_K), dtype)
                B_local = T.alloc_local((block_K, block_N), dtype)
                C_local = T.alloc_local((block_M, block_N), accum_dtype)
                # Zero-initialize the accumulator tile.
                for p in T.serial(block_M):
                    for w in T.serial(block_N):
                        C_local[p, w] = 0
                for ko in T.serial(K // block_K):
                    # Load the A and B tiles for this K-block.
                    for i in T.serial(block_M):
                        for k in T.serial(block_K):
                            A_local[i, k] = A[by * block_M + i, ko * block_K + k]
                    for k in T.serial(block_K):
                        for j in T.serial(block_N):
                            B_local[k, j] = B[ko * block_K + k, bx * block_N + j]
                    # Accumulate the partial product.
                    for i in T.serial(block_M):
                        for j in T.serial(block_N):
                            for k in T.serial(block_K):
                                C_local[i, j] += A_local[i, k] * B_local[k, j]
                # Write the result tile back to C.
                for i in T.serial(block_M):
                    for j in T.serial(block_N):
                        C[by * block_M + i, bx * block_N + j] = C_local[i, j]

        return matmul

    M, N, K = 1024, 512, 512
    block_M, block_N, block_K = M // 4, N // 4, K // 4
    cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K)
    compiled_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes", target="c")

    in_dtype = "float16"
    A = torch.randn(M, K, dtype=getattr(torch, in_dtype))
    B = torch.randn(K, N, dtype=getattr(torch, in_dtype))
    C = compiled_fun(A, B)
    C_torch = torch.matmul(A, B)
    tilelang.testing.torch_assert_close(
        C, C_torch, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
if __name__ == "__main__":
    tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from .utils import is_cuda_target, is_hip_target, is_cpu_target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_target_compute_version
from tvm.target import Target
@@ -41,8 +41,8 @@ class LibraryGenerator(object):
            command = [
                "nvcc",
                "-std=c++17",
                "-w",  # Disable all warning messages
                "-Xcudafe",
                "--diag_suppress=177",
                "--compiler-options",
@@ -66,7 +66,15 @@ class LibraryGenerator(object):
                "--shared",
                src.name,
            ]
        elif is_cpu_target(target):
            src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
            libpath = src.name.replace(".cpp", ".so")
            command = ["g++", "-std=c++17", "-fPIC", "-shared", src.name]
            with_tl = False
            command += [
                "-I" + TILELANG_TEMPLATE_PATH,
            ]
        else:
            raise ValueError(f"Unsupported target: {target}")
...
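For reference, a hedged sketch of the compile command this new CPU branch assembles. The temporary file path and the TILELANG_TEMPLATE_PATH value are illustrative, and any output-path flags appended later in the method fall outside this hunk.

# Illustrative values only: src.name comes from NamedTemporaryFile and
# TILELANG_TEMPLATE_PATH points at the bundled C headers.
src_name = "/tmp/tmpab12cd34.cpp"
template_path = "/path/to/tilelang/src"

command = ["g++", "-std=c++17", "-fPIC", "-shared", src_name]
command += ["-I" + template_path]
# command == ['g++', '-std=c++17', '-fPIC', '-shared',
#             '/tmp/tmpab12cd34.cpp', '-I/path/to/tilelang/src']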
@@ -38,6 +38,16 @@ def match_declare_kernel(source: str, annotation: str = "__global__") -> int:
    raise ValueError("No global kernel found in the source code")
def match_declare_kernel_cpu(source: str, annotation: str = "int32_t") -> int:
    pattern = r"int32_t\s+\w+"
    for line in source.split("\n"):
        if annotation in line:
            matched = re.findall(pattern, line)
            if len(matched) >= 1:
                return source.index(matched[0] + "(")
    raise ValueError("No global kernel found in the source code")
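A hedged usage sketch for this helper: the CPU wrapper calls it with annotation=function_name + "(" and uses the returned character offset to slice the kernel declaration out of the generated C source. The kernel name, argument list, and import path below are illustrative.

# Import path assumed; the helper lives in the utils module patched above.
from tilelang.jit.adapter.utils import match_declare_kernel_cpu

# Hypothetical generated C source with one CPU kernel entry point.
generated = "int32_t main_kernel(half* A, half* B, float* C) {\n  return 0;\n}"

offset = match_declare_kernel_cpu(generated, annotation="main_kernel(")
# offset is 0 here: generated[offset:] starts at "int32_t main_kernel(",
# which is exactly what create_call_func slices to recover the signature.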
def is_cuda_target(target: Target) -> bool:
    return target.kind.name == "cuda"
@@ -46,6 +56,10 @@ def is_hip_target(target: Target) -> bool:
    return target.kind.name == "hip"
def is_cpu_target(target: Target) -> bool:
    return target.kind.name in ["c"]
def get_annotated_mod(
        func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
        target: Union[str, Target] = "auto",
...
@@ -6,9 +6,10 @@ from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from tvm.target import Target
from .utils import match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target, is_cpu_target, get_annotated_mod
import re
import logging
import textwrap
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
    cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
@@ -374,6 +375,182 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
        function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
class TLCPUSourceWrapper(object):
    _TYPE_MAP = {
        "float32": "float",
        "float16": "half",
        "int32": "int32_t",
    }

    INIT_FUNC = textwrap.dedent('''
        #ifdef __cplusplus
        extern "C"
        #endif
        int32_t init() {
            return 0;
        }
    ''')

    CALL_PREFIX = textwrap.dedent("""
        #ifdef __cplusplus
        extern "C"
        #endif
        int32_t call({}) {{
            return {};
        }}
    """)

    backend = "tl"

    def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
        self.function_names: Optional[str] = None
        self.dynamic_smem_buf: Optional[int] = None
        self.parse_source_information()
        self.srcpath: Optional[str] = None
        self.libpath: Optional[str] = None
        self.lib_code: Optional[str] = self.update_lib_code(source)
    def create_call_func(self, code, function_informations):
        # Extract the set of dynamic symbolic names used in the primary function
        dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)

        function_args = []
        # Collect function arguments based on the primary function's parameters and buffer mappings
        for param in self.prim_func.params:
            if param in self.prim_func.buffer_map:
                buffer = self.prim_func.buffer_map[param]
                function_args.append({
                    "name": buffer.name,
                    "type": self._TYPE_MAP[buffer.dtype] + "*",
                })
            elif isinstance(param, tvm.tir.Var):
                function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
            else:
                raise ValueError(
                    f"Parameter {param} is not in the buffer map of the primary function.")
        # Add dynamic symbols as integer arguments
        for dyn_sym in dynamic_symbolic_set:
            function_args.append({"name": dyn_sym, "type": "int"})

        # Format the function arguments for declaration
        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

        def func_call_args(s, function_args):
            # Extract the function call arguments matching the function definition

            def maybe_desc(name: str, matches: List[str], i: int):
                match = matches[i]
                if match != name + "_desc":
                    return False
                desc_decls = []
                if i > 0:
                    desc_decls.append(matches[i - 1])
                if i < len(matches) - 1:
                    desc_decls.append(matches[i + 1])
                return any([decl == "CUtensorMap" for decl in desc_decls])

            pattern = r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)"
            matches = re.findall(pattern, s)
            call_args = []
            for i, match in enumerate(matches):
                for arg in function_args:
                    if arg["name"] == match or maybe_desc(arg["name"], matches, i):
                        call_args.append(match)
            return call_args

        def legalize_c(p):
            # Convert TIR expressions to legal C expressions
            # Directly convert to string since the special case handling
            # does not alter the string representation for `tvm.tir.Var` and `IntImm`.
            # Replace Python's floor division operator with C's division operator
            if isinstance(p, tvm.tir.IntImm):
                p = int(p)
            return str(p).replace("//", "/")

        _call_str = """"""
        for function_name, _ in function_informations.items():
            # Find the location of the global kernel function in the code
            index = match_declare_kernel_cpu(code, function_name + "(")
            # Analyze the function declaration to prepare for argument extraction
            declaration = code[index:].split(";")[0]
            # Identify the start of the function body to insert arguments
            index = code.index("{", index)
            call_args = ", ".join(func_call_args(declaration, function_args))
            _call_str += "{}({})".format(function_name, call_args)

        # Wrap the kernel dispatch logic in an external C function
        host_func = self.CALL_PREFIX.format(def_args, _call_str)
        return host_func
    def parse_source_information(self):
        device_mod, host_mod = get_annotated_mod(self.mod, self.target)
        assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
        assert (len(host_mod.functions) == 1), "Only support one function in host module."

        function_names = []
        for g_var, _ in device_mod.functions.items():
            function_name = g_var.name_hint
            function_names.append(function_name)

        self.function_names = function_names

    def get_dynamic_symbolic_set(self, prim_func):
        # Determine the set of dynamic symbols used in the function
        dynamic_symbolic_set: List[str] = []
        for param in prim_func.params:
            if param in prim_func.buffer_map:
                buffer = prim_func.buffer_map[param]
                for dim in buffer.shape:
                    if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
                        dynamic_symbolic_set.append(dim.name)
        return dynamic_symbolic_set

    def get_cpu_init_func(self):
        init_funcs = self.INIT_FUNC
        return init_funcs
    def update_lib_code(self, code: str):
        # Update the library code with the given code string
        self.lib_code = code
        # Get the function names
        function_names = self.function_names
        # Get the CPU initialization function
        init_func = self.get_cpu_init_func()

        # Organize function information for code generation
        function_informations = {}
        for function_name in function_names:
            function_informations[function_name] = {
                "function_name": function_name,
            }

        # Create the call function wrapper for the CPU kernel
        call_func = self.create_call_func(code, function_informations)
        # Combine the source, initialization function, and call function to form the complete library code
        lib_code = self.source + init_func + call_func
        return lib_code

    @property
    def prim_func(self):
        if len(self.mod.get_global_vars()) == 1:
            return self.mod[self.mod.get_global_vars()[0]]
        elif "main" in self.mod:
            return self.mod["main"]
        else:
            for _, function in self.mod.functions_items():
                attr = function.attrs
                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
                    return function
            raise ValueError("Cannot find primary function in the module.")
class TLWrapper(BaseWrapper):

    def __init__(self, target: Target):
@@ -392,6 +569,8 @@ class TLWrapper(BaseWrapper):
            wrapper_class = TLCUDASourceWrapper
        elif is_hip_target(self.target):
            wrapper_class = TLHIPSourceWrapper
        elif is_cpu_target(self.target):
            wrapper_class = TLCPUSourceWrapper
        else:
            raise ValueError(f"Unsupported platform: {self.arch.platform}")
        wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
...