Unverified commit 49c85715, authored by Elevator14B and committed by GitHub

Fix various issues under `int64_t` static and dynamic shape. (#1218)



* Fix various issues under int64_t static and dynamic shape.

* Resolve reviewed issues.

* Add unit test.

* fix

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent e805f8e5
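For context (editorial note, not part of the commit): "static shape" here means an extent baked in as a plain Python integer at JIT time, while "dynamic shape" means a T.symbolic extent resolved per launch; both paths previously had 32-bit assumptions in places. Following the unit test added below, the two declarations look roughly like:

import tilelang.language as T

n_dynamic = T.symbolic("n", "int64")  # dynamic: passed to the kernel at launch as an int64 argument
n_static = 2**32                      # static: a Python int captured when the kernel is JIT-compiled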
@@ -6,6 +6,7 @@
 #include "tvm/node/structural_hash.h"
 #include "tvm/tir/builtin.h"
 #include "tvm/tir/expr.h"
+#include "tvm/tir/op.h"
 #include "tvm/tir/stmt.h"
 #include "tvm/tir/stmt_functor.h"
 #include "tvm/tir/transform.h"
@@ -62,7 +63,8 @@ private:
   Stmt build(Stmt body) {
     auto analyzer = arith::Analyzer{};
     for (const auto &e : items) {
-      auto simplified = analyzer.Simplify(GT(e.expr, 0));
+      auto simplified =
+          analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype)));
       std::stringstream ss;
       ss << "Buffer shape should be greater than 0: shape `" << e.expr
          << "` from buffer ";
......
import tilelang
import tilelang.language as T


@tilelang.jit
def fill_symbolic(value: float, dtype="bfloat16"):
    n = T.symbolic("n", "int64")
    block_n = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
            # Doesn't yet work with int64-shaped global tensor
            # T.fill(x[bx * block_n : (bx + 1) * block_n], value)
            for i in T.Parallel(block_n):
                x[bx * block_n + i] = value

    return main


def run_fill_symbolic(n: int):
    import torch
    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_symbolic(1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_symbolic():
    # Requires 8GB VRAM
    run_fill_symbolic(2**32)


@tilelang.jit
def fill_static(n: int, value: float, dtype="bfloat16"):
    block_n = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
            # Doesn't yet work with int64-shaped global tensor
            # T.fill(x[bx * block_n : (bx + 1) * block_n], value)
            for i in T.Parallel(block_n):
                x[bx * block_n + i] = value

    return main


def run_fill_static(n: int):
    import torch
    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_static(n, 1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_static():
    # Requires 8GB VRAM
    run_fill_static(2**32)


if __name__ == "__main__":
    test_fill_symbolic()
    test_fill_static()
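A quick sanity check on why n = 2**32 exercises the 64-bit path (editorial arithmetic, not part of the test): 2**32 bfloat16 elements occupy 8 GiB (hence the VRAM note), and the largest linear index no longer fits in a signed 32-bit integer.

n = 2**32
block_n = 512
num_blocks = (n + block_n - 1) // block_n                # 8_388_608 thread blocks
max_index = (num_blocks - 1) * block_n + (block_n - 1)   # 4_294_967_295
print(max_index > 2**31 - 1)                             # True: overflows int32, so int64 indexing is required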
@@ -267,9 +267,9 @@ cdef class CythonKernelWrapper:
         # Add dynamic dimension values to kernel arguments
         for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
             if ref_id == 0:
-                call_args.append(tensor_list[buffer_idx].shape[shape_idx])
+                call_args.append(ctypes.c_int64(tensor_list[buffer_idx].shape[shape_idx]))
             else:
-                call_args.append(tensor_list[buffer_idx].stride(shape_idx))
+                call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx)))

         # Add CUDA stream to kernel arguments
         call_args.append(ctypes.c_void_p(stream))
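Why wrapping the shape and stride values in ctypes.c_int64 matters (editorial sketch, not part of the diff): ctypes integer types perform no overflow checking, so routing a 2**32 extent through a 32-bit argument silently truncates it. Assuming that documented wrap-around behaviour:

import ctypes

n = 2**32                        # dynamic extent from the test above
print(ctypes.c_int32(n).value)   # 0 -- truncated, the kernel would see an empty dimension
print(ctypes.c_int64(n).value)   # 4294967296 -- preserved, matching the wrapper change above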
......
@@ -313,9 +313,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
-        for dyn_sym in dynamic_symbolic_set:
+        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
             if dyn_sym not in [arg["name"] for arg in function_args]:
-                function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
+                function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})

         function_args.append(self.get_stream_type())
......
@@ -220,9 +220,9 @@ class TLCUDASourceWrapper:
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
-        for dyn_sym in dynamic_symbolic_set:
+        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
             if dyn_sym not in [arg["name"] for arg in function_args]:
-                function_args.append({"name": dyn_sym, "type": "int"})
+                function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})

         function_args.append(self.get_stream_type())
@@ -405,18 +405,20 @@ class TLCUDASourceWrapper:
     def get_dynamic_symbolic_set(self, prim_func):
         # Determine the set of dynamic symbols used in the function
-        dynamic_symbolic_set: list[str] = []
+        dynamic_symbolic_set: dict[str, str] = {}

-        def unique_push_back(name: str):
+        def unique_push_back(name: str, dtype: str):
             if name not in dynamic_symbolic_set:
-                dynamic_symbolic_set.append(name)
+                dynamic_symbolic_set[name] = dtype
+            else:
+                assert dtype == dynamic_symbolic_set[name]

         for param in prim_func.params:
             if param in prim_func.buffer_map:
                 buffer = prim_func.buffer_map[param]
                 for dim in buffer.shape:
                     if isinstance(dim, tvm.tir.Var):
-                        unique_push_back(dim.name)
+                        unique_push_back(dim.name, str(dim.dtype))

         # Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape.
         for param in prim_func.params:
@@ -424,9 +426,9 @@ class TLCUDASourceWrapper:
                 buffer = prim_func.buffer_map[param]
                 for stride in buffer.strides:
                     if isinstance(stride, tvm.tir.Var):
-                        unique_push_back(stride.name)
+                        unique_push_back(stride.name, str(stride.dtype))

-        return dynamic_symbolic_set
+        return list(dynamic_symbolic_set.items())

     def get_init_func(self):
         # Initialize an empty string for the CUDA function call
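For orientation (editorial sketch, not part of the diff): get_dynamic_symbolic_set now yields (name, dtype) pairs rather than bare names, so the wrappers can pick an argument type per symbol via self._lookup_type. A hypothetical mapping in the same spirit, purely for illustration (the repository's _lookup_type may differ in detail):

# Hypothetical dtype-to-C-type table, for illustration only.
_C_TYPE_FOR_DTYPE = {"int32": "int", "int64": "int64_t"}

dynamic_symbolic_set = [("n", "int64")]   # e.g. the pairs returned for the test kernel above
function_args = [{"name": name, "type": _C_TYPE_FOR_DTYPE[dtype]}
                 for name, dtype in dynamic_symbolic_set]
print(function_args)   # [{'name': 'n', 'type': 'int64_t'}]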
@@ -665,8 +667,8 @@ class TLCPUSourceWrapper:
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
-        for dyn_sym in dynamic_symbolic_set:
-            function_args.append({"name": dyn_sym, "type": "int"})
+        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
+            function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})

         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
@@ -715,14 +717,14 @@ class TLCPUSourceWrapper:
     def get_dynamic_symbolic_set(self, prim_func):
         # Determine the set of dynamic symbols used in the function
-        dynamic_symbolic_set: list[str] = []
+        dynamic_symbolic_set: dict[str, str] = {}
         for param in prim_func.params:
             if param in prim_func.buffer_map:
                 buffer = prim_func.buffer_map[param]
                 for dim in buffer.shape:
                     if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
-                        dynamic_symbolic_set.append(dim.name)
-        return dynamic_symbolic_set
+                        dynamic_symbolic_set[dim.name] = str(dim.dtype)
+        return list(dynamic_symbolic_set.items())

     def get_cpu_init_func(self):
         # Provide init() and get_last_error() for CPU backend
......