Commit c770a58f authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Language] Introduce `T.alloc_var` to define a variable like `int var;` (#255)

* [Enhancement] Add matrix multiplication functions for integer and float variables in Cython JIT

- Introduced `matmul_int_variable` and `matmul_float_variable` functions to support matrix multiplication with dynamic shapes and additional parameters.
- Implemented corresponding `run_matmul_int_variable` and `run_matmul_float_variable` functions for testing.
- Updated test cases to validate the new matrix multiplication implementations.
- Enhanced error handling in library initialization and compilation processes across various modules.
- Improved dynamic memory handling in CUDA kernel initialization to provide better error reporting.

* lint fix

* optimize

* Support var define

* lint fix

* Update TVM submodule and add alloc_variable function to allocate local variables in TileLang

- Updated the TVM submodule to the latest commit.
- Introduced `alloc_variable` function in `allocate.py` to support local variable allocation with specified data types and scopes.

* lint fix

* Refactor variable allocation functions for consistency

- Renamed `alloc_variable` to `alloc_var` across multiple files for improved consistency.
- Updated corresponding test functions to reflect the new naming convention.
- Adjusted imports in `__init__.py` to align with the changes.
parent 316d3b97
Subproject commit c1c2a08a53f24886d2f82839fe304f2f1b6d0973
Subproject commit ed1cb8dd61d81193ab33da03b9fcc9c4a04c3b60
......@@ -15,6 +15,7 @@
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "arith/pattern_match.h"
#include "target/source/ptx.h"
namespace tvm {
......@@ -735,7 +736,13 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
buffer_str = temp.str();
}
if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data);
}
if (scope == "local.var") {
os << vid;
return os.str();
}
std::string index_str = PrintExpr(index);
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
......@@ -1274,7 +1281,6 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode *buffer = op->buffer_var.as<VarNode>();
......@@ -1312,7 +1318,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
if (scope == "shared") {
stream << ' ' << vid << '[' << constant_size << "];\n";
} else if (scope == "local") {
stream << ' ' << vid << '[' << constant_size << "];\n";
} else if (scope == "local.var") {
stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
<< ";\n";
} else {
ICHECK(false) << "Unsupported scope: " << scope;
}
}
RegisterHandleType(op->buffer_var.get(), op->dtype);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_device_storage_access.cc
* \brief Lower the special device storage access.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
/*!
 * \brief Lower special device storage access on allocations.
 *
 * For allocations in tagged storage scopes (other than ".dyn" and ".var",
 * which are handled elsewhere), records the scope's MemoryInfo and rewrites:
 *  - Allocate: replaced by a LetStmt binding the buffer var to the memory's
 *    head address when addressable, or elided entirely when not.
 *  - DeclBuffer: dropped when its backing allocation was elided.
 *  - tvm_access_ptr calls: rewritten into address arithmetic in the target
 *    memory's units.
 */
class StorageAccessInfoLower : public StmtExprMutator {
public:
  Stmt VisitStmt_(const AllocateNode *op) final {
    auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
    // ".dyn" (dynamic shared) and ".var" (scalar variable) tags are lowered
    // by other passes / the codegen, so they pass through untouched here.
    if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var") {
      auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
      ICHECK(info.defined())
          << "Cannot find memory info of " << scope.to_string();
      ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end())
          << "Double allocation of " << scope.to_string();
      storage_info_[op->buffer_var.get()] = info;
      // Lower allocate to device allocate when needed.
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AllocateNode>();
      if (info->head_address.defined()) {
        // Addressable special memory: bind the buffer var to its head address.
        return LetStmt(op->buffer_var, info->head_address, op->body);
      } else {
        // Non-addressable memory: the allocation itself is elided.
        return op->body;
      }
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  Stmt VisitStmt_(const DeclBufferNode *op) final {
    // Drop DeclBuffer nodes whose backing allocation was elided above
    // (non-addressable memory has no head address to declare a buffer over).
    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
    if (auto it = storage_info_.find(node->buffer->data.get());
        it != storage_info_.end() && !it->second->head_address.defined()) {
      return node->body;
    } else {
      return std::move(node);
    }
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      return MakeAccessPtr(op);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

private:
  // Rewrite a 5-argument tvm_access_ptr(dtype, buffer, offset, ...) call.
  PrimExpr MakeAccessPtr(const CallNode *op) {
    // Specially handle the buffer packed intrinsic
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    ICHECK_EQ(op->args.size(), 5U);
    DataType dtype = op->args[0].dtype();
    const VarNode *buffer = op->args[1].as<VarNode>();
    Var buffer_var = Downcast<Var>(op->args[1]);
    PrimExpr offset = op->args[2];
    auto it = storage_info_.find(buffer);
    if (it != storage_info_.end() && it->second.defined()) {
      // Buffer lives in recorded special memory: use tagged lowering.
      return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset,
                                 it->second);
    }
    ICHECK(op->dtype.is_handle());
    // Change to address_of
    return AddressOffset(buffer_var, dtype, offset);
  }

  // Convert an access pointer into either an address offset (handle types)
  // or a plain index expressed in the memory's unit size.
  PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var,
                               DataType dtype, PrimExpr offset,
                               const MemoryInfo &info) {
    if (ptr_type.is_handle()) {
      // Typo fixed: "adddressable" -> "addressable".
      ICHECK(info->head_address.defined())
          << buffer_var << " is not addressable.";
      return AddressOffset(buffer_var, dtype, offset);
    }
    // Scale the element offset into units of the memory's unit_bits.
    int dtype_bits = dtype.bits() * dtype.lanes();
    ICHECK_EQ(info->unit_bits % dtype_bits, 0);
    return cast(
        ptr_type,
        analyzer_.Simplify(
            offset / make_const(offset.dtype(), info->unit_bits / dtype_bits)));
  }

  // The storage scope of each buffer
  std::unordered_map<const VarNode *, MemoryInfo> storage_info_;
  // analyzer
  arith::Analyzer analyzer_;
};
/*! \brief Apply StorageAccessInfoLower to a statement. */
Stmt LowerStorageAccessInfo(Stmt stmt) {
  StorageAccessInfoLower mutator;
  return mutator(std::move(stmt));
}
namespace transform {
using namespace tir::transform;
/*!
 * \brief Create the PrimFunc pass that lowers attached storage access
 *        information on device functions.
 */
Pass LowerDeviceStorageAccessInfo() {
  auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) {
    auto *node = func.CopyOnWrite();
    node->body = StorageAccessInfoLower()(std::move(node->body));
    return func;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerDeviceStorageAccessInfo",
                            {});
}

TVM_REGISTER_GLOBAL("tl.transform.LowerDeviceStorageAccessInfo")
    .set_body_typed(LowerDeviceStorageAccessInfo);
} // namespace transform
} // namespace tl
} // namespace tvm
import tilelang.testing
def alloc_var(
    N,
    block_N,
    dtype,
):
    """Build a TileLang kernel copying A -> B through shared memory while
    declaring a scalar with `T.alloc_var` (exercises the `local.var` scope)."""
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((N,), dtype),
            B: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_shared = T.alloc_shared([block_N], dtype)
            # Scalar variable; the assignment keeps it live so it shows up
            # in the generated kernel source.
            tmp = T.alloc_var(dtype)
            tmp = 1  # noqa: F841
            T.copy(A[bx * block_N], A_shared)
            T.copy(A_shared, B[bx * block_N])

    return main
def run_alloc_var(
    N,
    block_N,
    dtype,
    min=None,
    max=None,
):
    """Compile the `alloc_var` kernel and check the scalar survives codegen.

    Parameters
    ----------
    N, block_N : int
        Total length and per-block tile size.
    dtype : str
        Element dtype, e.g. "float16".
    min, max : optional
        Currently unused; kept for call-site compatibility.
        NOTE(review): these shadow the `min`/`max` builtins - consider
        renaming upstream.
    """
    program = alloc_var(N, block_N, dtype)
    kernel = tilelang.compile(program, out_idx=[1])
    code = kernel.get_kernel_source()
    # `local.var` buffers should lower to a plain scalar assignment.
    assert "tmp =" in code, (
        "expected scalar assignment to `tmp` in generated kernel source")
def test_alloc_var():
    """float16 smoke test for `T.alloc_var`."""
    run_alloc_var(N=1024, block_N=128, dtype="float16")
def alloc_var_add(
    N,
    block_N,
    dtype,
):
    """Build a TileLang kernel computing B = A + tmp, where `tmp` is a scalar
    declared with `T.alloc_var` and set to 1."""
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((N,), dtype),
            B: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_shared = T.alloc_shared([block_N], dtype)
            tmp = T.alloc_var(dtype)
            tmp = 1  # noqa: F841
            T.copy(A[bx * block_N], A_shared)
            # Add the scalar to every element of the tile in parallel.
            for i in T.Parallel(block_N):
                A_shared[i] = A_shared[i] + tmp
            T.copy(A_shared, B[bx * block_N])

    return main
def run_alloc_var_add(
    N,
    block_N,
    dtype,
):
    """Compile the `alloc_var_add` kernel and check the scalar survives codegen.

    Only inspects the generated source; the element-wise add is not executed
    on device here, so numerical correctness is not verified.
    """
    program = alloc_var_add(N, block_N, dtype)
    kernel = tilelang.compile(program, out_idx=[1])
    code = kernel.get_kernel_source()
    assert "tmp =" in code, (
        "expected scalar assignment to `tmp` in generated kernel source")
def test_alloc_var_add():
    """float16 smoke test for scalar-variable addition."""
    run_alloc_var_add(N=1024, block_N=128, dtype="float16")
# Allow invoking this test module directly.
if __name__ == "__main__":
    tilelang.testing.main()
......@@ -3,6 +3,7 @@
import os
import os.path as osp
from typing import Union, Optional, Callable, List
import tilelang.transform
from tilelang import tvm as tvm
from tvm import tir
from tvm.ir import CallingConv
......@@ -141,7 +142,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
host_mod = tir.transform.LowerIntrin()(host_mod)
host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod)
host_mod = tir.transform.CombineContextCall()(host_mod)
if target_host.kind.name == "llvm":
host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
......@@ -153,7 +154,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
......@@ -168,10 +169,9 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
device_mod, target)
......
......@@ -13,6 +13,7 @@ from .allocate import (
alloc_local, # noqa: F401
alloc_shared, # noqa: F401
alloc_fragment, # noqa: F401
alloc_var, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm # noqa: F401
......@@ -35,6 +36,8 @@ from .customize import (
)
from .builtin import * # noqa: F401
from .memscope import * # noqa: F401
def symbolic(name: str, dtype: str = "int32"):
    """Create a named scalar TIR variable — typically used as a dynamic
    (symbolic) dimension in kernel shapes."""
    return tir.Var(name, dtype)
......
......@@ -13,3 +13,7 @@ def alloc_local(shape, dtype, scope="local"):
def alloc_fragment(shape, dtype, scope="local.fragment"):
    """Allocate a buffer in the `local.fragment` scope (presumably a
    per-thread register fragment — confirm against the lowering passes)."""
    return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_var(dtype, scope="local.var"):
    """Allocate a single-element buffer in `local.var` scope; the CUDA codegen
    emits it as a plain scalar variable (e.g. `T var = 0;`)."""
    return T.alloc_buffer([1], dtype, scope=scope)
from tvm._ffi.registry import register_func
from tvm.ir import make_node
@register_func("tvm.info.mem.local.var")
def mem_info_local_var():
    """Register MemoryInfo for the `local.var` storage scope.

    `head_address=None` marks the scope as non-addressable, so lowering
    passes that consult this info (e.g. storage-access lowering) elide the
    allocation instead of binding a base pointer.
    """
    return make_node(
        "MemoryInfo",
        unit_bits=8,  # byte-sized addressing units
        max_num_bits=64,  # capacity limit of the scope, in bits
        max_simd_bits=128,
        head_address=None,  # no base address: scope is not addressable
    )
......@@ -256,3 +256,18 @@ def InjectPTXAsyncCopy():
The result pass
"""
return _ffi_api.InjectPTXAsyncCopy() # type: ignore
def LowerDeviceStorageAccessInfo():
    """Lower attached storage access information on device.

    Note
    ----
    Run this pass after all storage access analysis finish.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    make_pass = _ffi_api.LowerDeviceStorageAccessInfo  # type: ignore
    return make_pass()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment