Commit 782ca9f6 authored by xs-keju's avatar xs-keju Committed by GitHub
Browse files

Add cpu jit with backend ctypes (#154)



* Add cpu jit with backend ctypes

* Resolve some lint issues

* Apply PR feedback on head file and kernel example

* Add test cases

* Resolve formatting issues

* Resolve formatting issues

---------
Co-authored-by: default avatarxxw <1990389406@qq.con>
parent 3486e27e
#pragma once
#include "half.hpp"
#include <math.h>
#include <stdbool.h>
using half_float::half;
// Not Implemented
\ No newline at end of file
#pragma once
// Not Implemented
This diff is collapsed.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <math.h>
#include <stdbool.h>
// Not Implemented
F
\ No newline at end of file
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
// Not Implemented
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
......@@ -62,5 +60,58 @@ def test_matmul_codegen():
assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32)
def test_matmul_compile():
def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# a simple kernel just for jit test
@T.prim_func
def matmul(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
A_local = T.alloc_local((block_M, block_K), dtype)
B_local = T.alloc_local((block_K, block_N), dtype)
C_local = T.alloc_local((block_M, block_N), accum_dtype)
for p in T.serial(block_M):
for w in T.serial(block_N):
C_local[p, w] = 0
for ko in T.serial(K // block_K):
for i in T.serial(block_M):
for k in T.serial(block_K):
A_local[i, k] = A[by * block_M + i, ko * block_K + k]
for k in T.serial(block_K):
for j in T.serial(block_N):
B_local[k, j] = B[ko * block_K + k, bx * block_N + j]
for i in T.serial(block_M):
for j in T.serial(block_N):
for k in T.serial(block_K):
C_local[i, j] += A_local[i, k] * B_local[k, j]
for i in T.serial(block_M):
for j in T.serial(block_N):
C[by * block_M + i, bx * block_N + j] = C_local[i, j]
return matmul
M, N, K = 1024, 512, 512
block_M, block_N, block_K = M // 4, N // 4, K // 4
cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K)
complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes", target="c")
in_dtype = "float16"
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype))
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype))
C = complied_fun(A, B)
C_torch = torch.matmul(A, B)
tilelang.testing.torch_assert_close(C, C_torch, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from .utils import is_cuda_target, is_hip_target
from .utils import is_cuda_target, is_hip_target, is_cpu_target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_target_compute_version
from tvm.target import Target
......@@ -41,8 +41,8 @@ class LibraryGenerator(object):
command = [
"nvcc",
"-std=c++17",
"-w", # Disable all warning messages
"-std=c++17",
"-w", # Disable all warning messages
"-Xcudafe",
"--diag_suppress=177",
"--compiler-options",
......@@ -66,7 +66,15 @@ class LibraryGenerator(object):
"--shared",
src.name,
]
elif is_cpu_target(target):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so")
command = ["g++", "-std=c++17", "-fPIC", "-shared", src.name]
with_tl = False
command += [
"-I" + TILELANG_TEMPLATE_PATH,
]
else:
raise ValueError(f"Unsupported target: {target}")
......
......@@ -38,6 +38,16 @@ def match_declare_kernel(source: str, annotation: str = "__global__") -> int:
raise ValueError("No global kernel found in the source code")
def match_declare_kernel_cpu(source: str, annotation: str = "int32_t") -> int:
pattern = r"int32_t\s+\w+"
for line in source.split("\n"):
if annotation in line:
matched = re.findall(pattern, line)
if len(matched) >= 1:
return source.index(matched[0] + "(")
raise ValueError("No global kernel found in the source code")
def is_cuda_target(target: Target) -> bool:
return target.kind.name == "cuda"
......@@ -46,6 +56,10 @@ def is_hip_target(target: Target) -> bool:
return target.kind.name == "hip"
def is_cpu_target(target: Target) -> bool:
return target.kind.name in ["c"]
def get_annotated_mod(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
target: Union[str, Target] = "auto",
......
......@@ -6,9 +6,10 @@ from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from tvm.target import Target
from .utils import match_declare_kernel, is_cuda_target, is_hip_target, get_annotated_mod
from .utils import match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target, is_cpu_target, get_annotated_mod
import re
import logging
import textwrap
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
......@@ -374,6 +375,182 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
class TLCPUSourceWrapper(object):
_TYPE_MAP = {
"float32": "float",
"float16": "half",
"int32": "int32_t",
}
INIT_FUNC = textwrap.dedent('''
#ifdef __cplusplus
extern "C"
#endif
int32_t init() {
return 0;
}
''')
CALL_PREFIX = textwrap.dedent("""
#ifdef __cplusplus
extern "C"
#endif
int32_t call({}) {{
return {};
}}
""")
backend = "tl"
backend = "tl"
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)
def create_call_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
function_args = []
# Collect function arguments based on primary function's parameters and buffer mappings
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "*",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s, function_args):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
match = matches[i]
if match != name + "_desc":
return False
desc_decls = []
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
desc_decls.append(matches[i + 1])
return any([decl == "CUtensorMap" for decl in desc_decls])
pattern = r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for i, match in enumerate(matches):
for arg in function_args:
if arg["name"] == match or maybe_desc(arg["name"], matches, i):
call_args.append(match)
return call_args
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
_call_str = """"""
for function_name, _ in function_informations.items():
# Find the location of the global kernel function in the code
index = match_declare_kernel_cpu(code, function_name + "(")
# Analyze the function declaration to prepare for argument extraction
declaration = code[index:].split(";")[0]
# Identify the start of the function body to insert arguments
index = code.index("{", index)
call_args = ", ".join(func_call_args(declaration, function_args))
_call_str += "{}({})".format(function_name, call_args)
# Wrap the kernel dispatch logic in an external C function
host_func = self.CALL_PREFIX.format(def_args, _call_str)
return host_func
def parse_source_information(self):
device_mod, host_mod = get_annotated_mod(self.mod, self.target)
assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
assert (len(host_mod.functions) == 1), "Only support one function in host module."
function_names = []
for g_var, _ in device_mod.functions.items():
function_name = g_var.name_hint
function_names.append(function_name)
self.function_names = function_names
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = []
for param in prim_func.params:
if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param]
for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
dynamic_symbolic_set.append(dim.name)
return dynamic_symbolic_set
def get_cpu_init_func(self):
init_funcs = self.INIT_FUNC
return init_funcs
def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
# Get the function names
function_names = self.function_names
# Get the CPU initialization function
init_func = self.get_cpu_init_func()
# Organize function information for code generation
function_informations = {}
for function_name in function_names:
function_informations[function_name] = {
"function_name": function_name,
}
# Create the call function wrapper for the CPU kernel
call_func = self.create_call_func(code, function_informations)
# Combine the source, initialization function, and call function to form the complete library code
lib_code = self.source + init_func + call_func
return lib_code
@property
def prim_func(self):
if len(self.mod.get_global_vars()) == 1:
return self.mod[self.mod.get_global_vars()[0]]
elif "main" in self.mod:
return self.mod["main"]
else:
for _, function in self.mod.functions_items():
attr = function.attrs
if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
return function
raise ValueError("Cannot find primary function in the module.")
class TLWrapper(BaseWrapper):
def __init__(self, target: Target):
......@@ -392,6 +569,8 @@ class TLWrapper(BaseWrapper):
wrapper_class = TLCUDASourceWrapper
elif is_hip_target(self.target):
wrapper_class = TLHIPSourceWrapper
elif is_cpu_target(self.target):
wrapper_class = TLCPUSourceWrapper
else:
raise ValueError(f"Unsupported platform: {self.arch.platform}")
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment