Commit c770a58f authored by Lei Wang, committed by LeiWang1999

[Language] Introduce `T.alloc_var` to define a variable like `int var;` (#255)

* [Enhancement] Add matrix multiplication functions for integer and float variables in Cython JIT

- Introduced `matmul_int_variable` and `matmul_float_variable` functions to support matrix multiplication with dynamic shapes and additional parameters.
- Implemented corresponding `run_matmul_int_variable` and `run_matmul_float_variable` functions for testing.
- Updated test cases to validate the new matrix multiplication implementations.
- Enhanced error handling in library initialization and compilation processes across various modules.
- Improved dynamic memory handling in CUDA kernel initialization to provide better error reporting.

* lint fix

* optimize

* Support var define

* lint fix

* Update TVM submodule and add alloc_variable function to allocate local variables in TileLang

- Updated the TVM submodule to the latest commit.
- Introduced `alloc_variable` function in `allocate.py` to support local variable allocation with specified data types and scopes.

* lint fix

* Refactor variable allocation functions for consistency

- Renamed `alloc_variable` to `alloc_var` across multiple files for improved consistency (a usage sketch follows this list).
- Updated corresponding test functions to reflect the new naming convention.
- Adjusted imports in `__init__.py` to align with the changes.
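
For orientation, here is a minimal `T.alloc_var` usage sketch distilled from the tests added in this commit (N=1024 and block_N=128 are the values the tests use; everything else follows the test file):

```python
import tilelang
import tilelang.language as T


@T.prim_func
def main(A: T.Buffer((1024,), "float16"), B: T.Buffer((1024,), "float16")):
    with T.Kernel(T.ceildiv(1024, 128), threads=128) as bx:
        A_shared = T.alloc_shared([128], "float16")
        tmp = T.alloc_var("float16")  # a scalar variable, not a 1-element array
        tmp = 1
        T.copy(A[bx * 128], A_shared)
        for i in T.Parallel(128):
            A_shared[i] = A_shared[i] + tmp
        T.copy(A_shared, B[bx * 128])


kernel = tilelang.compile(main, out_idx=[1])
assert "tmp =" in kernel.get_kernel_source()
```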
parent 316d3b97
- Subproject commit c1c2a08a53f24886d2f82839fe304f2f1b6d0973
+ Subproject commit ed1cb8dd61d81193ab33da03b9fcc9c4a04c3b60
@@ -15,6 +15,7 @@
 #include "../op/builtin.h"
 #include "../op/bulk_copy.h"
+#include "arith/pattern_match.h"
 #include "target/source/ptx.h"

 namespace tvm {
@@ -735,7 +736,13 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
     temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
     buffer_str = temp.str();
   }
+  if (scope.empty()) {
+    scope = GetPtrStorageScope(buffer->data);
+  }
+  if (scope == "local.var") {
+    os << vid;
+    return os.str();
+  }
   std::string index_str = PrintExpr(index);
   if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
     // This is a special case, because CodegenCUDA::PrintType()
@@ -1274,7 +1281,6 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
 void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
   ICHECK(!is_zero(op->condition));
   std::string vid = AllocVarID(op->buffer_var.get());
   this->PrintIndent();
   std::string scope = GetPtrStorageScope(op->buffer_var);
   const VarNode *buffer = op->buffer_var.as<VarNode>();
@@ -1312,7 +1318,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
       scope == "shared") {
     constant_size = constant_size / (32 / op->dtype.bits());
   }
-  stream << ' ' << vid << '[' << constant_size << "];\n";
+  if (scope == "shared") {
+    stream << ' ' << vid << '[' << constant_size << "];\n";
+  } else if (scope == "local") {
+    stream << ' ' << vid << '[' << constant_size << "];\n";
+  } else if (scope == "local.var") {
+    stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
+           << ";\n";
+  } else {
+    ICHECK(false) << "Unsupported scope: " << scope;
+  }
 }
 RegisterHandleType(op->buffer_var.get(), op->dtype);
...
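
Taken together, the two codegen changes above mean a `local.var` allocation is emitted as a zero-initialized scalar and every reference to it prints the bare variable name rather than an indexed buffer access. A sketch of the expected effect (the commented CUDA is illustrative; the exact zero literal comes from `PrintExpr(make_const(op->dtype, 0))` for the buffer's dtype):

```python
import tilelang
import tilelang.language as T


@T.prim_func
def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    with T.Kernel(T.ceildiv(128, 128), threads=128) as bx:
        acc = T.alloc_var("float32")  # scope "local.var"
        acc = 2
        for i in T.Parallel(128):
            B[i] = A[i] + acc


src = tilelang.compile(main, out_idx=[1]).get_kernel_source()
# Roughly expected in the generated CUDA:
#   float acc = 0.000000e+00f;   // AllocateNode: scalar, not `float acc[1];`
#   acc = 2.000000e+00f;         // GetBufferRef: bare name, no `[0]` index
assert "acc =" in src
```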
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_device_storage_access.cc
* \brief Lower the special device storage access.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt VisitStmt_(const AllocateNode *op) final {
auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var") {
auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
ICHECK(info.defined())
<< "Cannot find memory info of " << scope.to_string();
ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end())
<< "Double allocation of " << scope.to_string();
storage_info_[op->buffer_var.get()] = info;
// Lower allocate to device allocate when needed.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
if (info->head_address.defined()) {
return LetStmt(op->buffer_var, info->head_address, op->body);
} else {
return op->body;
}
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
if (auto it = storage_info_.find(node->buffer->data.get());
it != storage_info_.end() && !it->second->head_address.defined()) {
return node->body;
} else {
return std::move(node);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
return MakeAccessPtr(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// tvm_access_ptr
PrimExpr MakeAccessPtr(const CallNode *op) {
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode *buffer = op->args[1].as<VarNode>();
Var buffer_var = Downcast<Var>(op->args[1]);
PrimExpr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.defined()) {
return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset,
it->second);
}
ICHECK(op->dtype.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var,
DataType dtype, PrimExpr offset,
const MemoryInfo &info) {
if (ptr_type.is_handle()) {
ICHECK(info->head_address.defined())
<< buffer_var << " is not addressable.";
return AddressOffset(buffer_var, dtype, offset);
}
int dtype_bits = dtype.bits() * dtype.lanes();
ICHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(
ptr_type,
analyzer_.Simplify(
offset / make_const(offset.dtype(), info->unit_bits / dtype_bits)));
}
// The storage scope of each buffer
std::unordered_map<const VarNode *, MemoryInfo> storage_info_;
// analyzer
arith::Analyzer analyzer_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt));
}
namespace transform {
using namespace tir::transform;
Pass LowerDeviceStorageAccessInfo() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
n->body = StorageAccessInfoLower()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerDeviceStorageAccessInfo",
{});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerDeviceStorageAccessInfo")
.set_body_typed(LowerDeviceStorageAccessInfo);
} // namespace transform
} // namespace tl
} // namespace tvm
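
The practical consequence of the `.var` exemption in `VisitStmt_(const AllocateNode *)` above: allocations whose scope tag is `.var` (like `.dyn`) pass through this lowering untouched and survive to codegen, where the scalar-emission path takes over. A minimal sketch of invoking the pass standalone through the Python wrapper added later in this commit (the trivial `copy16` PrimFunc is illustrative):

```python
from tilelang import tvm as tvm
from tvm.script import tir as T
import tilelang.transform


@T.prim_func
def copy16(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in range(16):
        B[i] = A[i]


mod = tvm.IRModule({"main": copy16})
# Tagged scopes are rewritten to their MemoryInfo head address (or dropped
# when none exists); ".dyn" and the new ".var" tags are left in place.
mod = tilelang.transform.LowerDeviceStorageAccessInfo()(mod)
```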
import tilelang.testing
def alloc_var(
N,
block_N,
dtype,
):
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype)
tmp = T.alloc_var(dtype)
tmp = 1 # noqa: F841
T.copy(A[bx * block_N], A_shared)
T.copy(A_shared, B[bx * block_N])
return main
def run_alloc_var(
    N,
    block_N,
    dtype,
):
program = alloc_var(N, block_N, dtype)
kernel = tilelang.compile(program, out_idx=[1])
code = kernel.get_kernel_source()
assert "tmp =" in code
def test_alloc_var():
run_alloc_var(1024, 128, "float16")
def alloc_var_add(
N,
block_N,
dtype,
):
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype)
tmp = T.alloc_var(dtype)
tmp = 1 # noqa: F841
T.copy(A[bx * block_N], A_shared)
for i in T.Parallel(block_N):
A_shared[i] = A_shared[i] + tmp
T.copy(A_shared, B[bx * block_N])
return main
def run_alloc_var_add(
N,
block_N,
dtype,
):
program = alloc_var_add(N, block_N, dtype)
kernel = tilelang.compile(program, out_idx=[1])
code = kernel.get_kernel_source()
assert "tmp =" in code
def test_alloc_var_add():
run_alloc_var_add(1024, 128, "float16")
if __name__ == "__main__":
tilelang.testing.main()
@@ -3,6 +3,7 @@
 import os
 import os.path as osp
 from typing import Union, Optional, Callable, List
+import tilelang.transform
 from tilelang import tvm as tvm
 from tvm import tir
 from tvm.ir import CallingConv
@@ -141,7 +142,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
     host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
     host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
     host_mod = tir.transform.LowerIntrin()(host_mod)
-    host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
+    host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod)
     host_mod = tir.transform.CombineContextCall()(host_mod)
     if target_host.kind.name == "llvm":
         host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
@@ -153,7 +154,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
 def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
-    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
+    device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
     device_mod = tir.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
@@ -168,10 +169,9 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
 def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
-    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
+    device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
     device_mod = tir.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
         device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
             device_mod, target)
...
@@ -13,6 +13,7 @@ from .allocate import (
     alloc_local,  # noqa: F401
     alloc_shared,  # noqa: F401
     alloc_fragment,  # noqa: F401
+    alloc_var,  # noqa: F401
 )
 from .copy import copy, c2d_im2col  # noqa: F401
 from .gemm import GemmWarpPolicy, gemm  # noqa: F401
@@ -35,6 +36,8 @@ from .customize import (
 )
 from .builtin import *  # noqa: F401
+from .memscope import *  # noqa: F401

 def symbolic(name: str, dtype: str = "int32"):
     return tir.Var(name, dtype)
...
@@ -13,3 +13,7 @@ def alloc_local(shape, dtype, scope="local"):
 def alloc_fragment(shape, dtype, scope="local.fragment"):
     return T.alloc_buffer(shape, dtype, scope=scope)
+
+
+def alloc_var(dtype, scope="local.var"):
+    return T.alloc_buffer([1], dtype, scope=scope)
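
`alloc_var` is thus sugar for a one-element buffer in the new `local.var` scope; the only difference from `alloc_local([1], dtype)` is the scope string, which selects the scalar codegen path. A sketch contrasting the two (names and the commented CUDA are illustrative):

```python
import tilelang
import tilelang.language as T


@T.prim_func
def main(A: T.Buffer((128,), "float16"), B: T.Buffer((128,), "float16")):
    with T.Kernel(T.ceildiv(128, 128), threads=128) as bx:
        arr = T.alloc_local([1], "float16")  # "local"     -> e.g. `half arr[1];`
        var = T.alloc_var("float16")         # "local.var" -> e.g. `half var = 0;`
        arr[0] = 1
        var = 2
        for i in T.Parallel(128):
            B[i] = A[i] + arr[0] + var


kernel = tilelang.compile(main, out_idx=[1])
```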
from tvm._ffi.registry import register_func
from tvm.ir import make_node
@register_func("tvm.info.mem.local.var")
def mem_info_local_var():
return make_node(
"MemoryInfo",
unit_bits=8,
max_num_bits=64,
max_simd_bits=128,
head_address=None,
)
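
This registration is what the C++ `GetMemoryInfo("local.var")` lookup resolves. It can be sanity-checked from Python through the same FFI registry (a sketch; the attribute names follow the `MemoryInfo` fields above):

```python
import tilelang.language  # importing this runs the registration above
import tvm

info = tvm._ffi.get_global_func("tvm.info.mem.local.var")()
assert info.unit_bits == 8 and info.max_num_bits == 64
```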
@@ -256,3 +256,18 @@ def InjectPTXAsyncCopy():
         The result pass
     """
     return _ffi_api.InjectPTXAsyncCopy()  # type: ignore
+
+
+def LowerDeviceStorageAccessInfo():
+    """Lower attached storage access information on the device.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+
+    Note
+    ----
+    Run this pass after all storage access analysis has finished.
+    """
+    return _ffi_api.LowerDeviceStorageAccessInfo()  # type: ignore
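
Per the note in the docstring, the pass belongs after storage access analysis; in this commit it replaces `tir.transform.LowerDeviceStorageAccessInfo` in both the host and device pipelines. A sketch of the device-side ordering, mirroring `device_codegen` from the engine diff earlier:

```python
import tilelang.transform
from tvm import tir


def lower_device(device_mod):
    # Same ordering as device_codegen: storage access info first,
    # then intrinsic lowering and simplification.
    device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
    device_mod = tir.transform.LowerIntrin()(device_mod)
    device_mod = tir.transform.Simplify()(device_mod)
    return device_mod
```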