"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1f4f7624560c1d54db3b51b79226d47728238db5"
Unverified Commit 7248a810 authored by Gabriel Wu, committed by GitHub

feat(cutedsl): add CuTeDSL backend (#1421)



* feat: CuTeDSL backend

* fix: clang-tidy

* fix: clang-format

* fix: ci

* fix: revert example gemm fp8

* fix: remove duplicate code

* fix: switch-case

* fix: fp16 silence

* fix: TVM IR print

* fix: useless tir

* fix: clang-format

* fix: remove tilelang/contrib/cutedsl/.gitignore

* fix: use hexfloat

* fix: gsym guard

* fix: unknown storage sync type

* fix: string literal

* fix: add args guard

* fix: name hint dedup

* fix: better find_kernel_by_pattern

* fix: set libpath for from_database path

* fix: guard buffer.strides

* fix: from guard

* fix: eviction guard

* fix: use thread local tma descs

* fix: ruff

* fix: drop tma_init_cpp

* fix: exc_info

* fix: negative unmatch early return

* fix: rename postproc func and add test

* fix: handle fast math according to pass config

* fix: dyn_sym parse

* fix: wrap_forward

* fix: use tvm_ffi.libinfo instead of cli

* fix: keep signature

* fix: C++ string safety

* fix: mark tma_store_add as unsupported

* fix: tvm version

* resolve ldsm and cpasync issues.

* fix: minor fixes

* fix: parse signature using ast

* fix: guard global_addr

* fix: create tempfile only when necessary

* fix: use logger.exception for exceptions

* fix: guard lib_path and host_func

* fix: remove tma_cpp_init and add timeout for cpp compile

* add timeout for mbarrier_wait.

* fix: _load_kernel_from_disk signature

* resolve codegen issues.

* fix: logger.exception

* add comment for div_by=1

* merge

* fix: reserve cutlass,cute,tl

* fix: guard tma_store

* fix: allow int64 offset in make_tensor_at_offset

* fix: guard barrier

* fix: add comments for div_by=16

* fix: div_by=1 issue

* delete div_by when offset is 0

* use tl.make_tensor when offset is 0

* fix: explicitly check cutedsl target

* fix: use param.torch_dtype()

---------
Co-authored-by: yuxic <yuxic@nvidia.com>
Co-authored-by: Yong <yong@local>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent a6f59f31
......@@ -370,8 +370,27 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
--ignore=./python/jit/test_tilelang_jit_cutedsl.py \
./python
# CuTeDSL JIT tests require GEMM v1 (must be set before importing tilelang).
# Run them in a dedicated step to avoid changing the default GEMM selection
# (and to keep the rest of the CUDA tests on GEMM v2).
- name: Run CuTeDSL JIT tests (GEMM v1) with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: cutedsl-tests
if: contains(matrix.runner.toolkit, 'CUDA')
env:
TILELANG_USE_GEMM_V1: "1"
run: |
cd testing
PYTEST=(
uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
# Avoid xdist contention on a single GPU by running this file in one worker.
"${PYTEST[@]}" --maxfail=3 --numprocesses=1 \
./python/jit/test_tilelang_jit_cutedsl.py
# AMD ROCm tests
- name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: rocm-tests
......
......@@ -215,7 +215,11 @@ elseif(USE_CUDA)
src/runtime/runtime.cc
src/target/ptx.cc
src/target/codegen_cuda.cc
src/target/codegen_py.cc
src/target/codegen_utils.cc
src/target/codegen_cutedsl.cc
src/target/rt_mod_cuda.cc
src/target/rt_mod_cutedsl.cc
)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS})
......
......@@ -14,7 +14,13 @@ cd examples
python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear
cd ..
# Run pytest in parallel (4 workers) for all tests in the testing/python directory
# Run pytest in parallel (4 workers) for all tests in the testing/python directory.
# IMPORTANT: CuTeDSL backend currently requires GEMM v1 (TILELANG_USE_GEMM_V1=1).
# Do NOT export it globally here, or you'll silently change the default GEMM selection
# for unrelated tests. Run the CuTeDSL JIT tests in a separate pytest invocation.
cd testing/python
python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear
python -m pytest -n 4 . --ignore=jit/test_tilelang_jit_cutedsl.py --verbose --color=yes --durations=0 --showlocals --cache-clear
# CuTeDSL JIT tests (isolate env + avoid xdist contention on a single GPU)
TILELANG_USE_GEMM_V1=1 python -m pytest -n 1 jit/test_tilelang_jit_cutedsl.py --verbose --color=yes --durations=0 --showlocals --cache-clear
cd ..
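The comments above stress that TILELANG_USE_GEMM_V1 is read when tilelang is imported, so a script that wants the CuTeDSL backend has to set it before the first import. A minimal sketch:

import os

# Must be set before the first tilelang import; the CuTeDSL backend currently
# requires the GEMM v1 lowering (TILELANG_USE_GEMM_V1=1).
os.environ["TILELANG_USE_GEMM_V1"] = "1"

import tilelang  # noqa: E402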
......@@ -7,3 +7,5 @@
# CUDA specific requirements
flash-attn==2.5.8
cuda-python==12.9.4
# CuTeDSL (CUTLASS Python DSL with CuTe support)
nvidia-cutlass-dsl>=4.3.1
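With the new dependency installed, the backend is selected purely through the compile target. The sketch below is a trimmed-down copy of test_gemm_f16f16f16_nn from the JIT test added later in this commit (same shapes, tiles, and dtypes); it assumes a CUDA GPU and TILELANG_USE_GEMM_V1=1 set before the import, as sketched above.

import tilelang
import tilelang.language as T

@T.prim_func
def main(
    A: T.Tensor((512, 768), "float16"),
    B: T.Tensor((768, 1024), "float16"),
    C: T.Tensor((512, 1024), "float16"),
):
    with T.Kernel(T.ceildiv(1024, 256), T.ceildiv(512, 128), threads=128) as (bx, by):
        A_shared = T.alloc_shared((128, 32), "float16")
        B_shared = T.alloc_shared((32, 256), "float16")
        C_local = T.alloc_fragment((128, 256), "float16")
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(768, 32), num_stages=2):
            T.copy(A[by * 128, k * 32], A_shared)
            T.copy(B[k * 32, bx * 256], B_shared)
            T.gemm(A_shared, B_shared, C_local, False, False)
        T.copy(C_local, C[by * 128, bx * 256])

kernel = tilelang.compile(main, out_idx=-1, target="cutedsl")
print(kernel.get_kernel_source())  # generated CuTeDSL (Python) source

As in the test, the compiled kernel can then be called directly with CUDA torch tensors (with out_idx=-1, kernel(A, B) returns C).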
/*!
* \file target/codegen_cutedsl.h
* \brief Utility to generate CuTeDSL code
*/
#ifndef TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
#define TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "codegen_py.h"
namespace tvm {
namespace codegen {
class CodeGenTileLangCuTeDSL final : public CodeGenTileLangPY {
public:
CodeGenTileLangCuTeDSL();
protected:
void PrintFuncDecorator_(std::ostream &os) override; // NOLINT(*)
void PreFunctionBody_(const PrimFunc &f) override;
protected:
void PrintType(DataType t, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitStmt_(const BufferStoreNode *op) override;
void VisitStmt_(const AllocateNode *op) override;
void VisitStmt_(const AttrStmtNode *op) override;
void VisitStmt_(const ForNode *op) override;
void VisitStmt_(const IfThenElseNode *op) override;
void VisitStmt_(const EvaluateNode *op) override;
protected:
virtual void PrintVecElemLoad_(const std::string &vec, DataType t, int i,
std::ostream &os); // NOLINT(*)
virtual void PrintVecElemStore_(const std::string &vec, DataType t, int i,
const std::string &value);
virtual void PrintVecStore_(const BufferNode *buffer, DataType t,
PrimExpr base, const std::string &value);
void PrintVecBinaryOp_(const std::string &opstr, DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os); // NOLINT(*)
void PrintBinaryExpr_(const std::string &opstr, DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) override; // NOLINT(*)
void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr,
std::ostream &os) override; // NOLINT(*)
void PrintCallExtern_(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) override; // NOLINT(*)
std::string GetBufferPtr_(const BufferNode *buffer, PrimExpr index);
std::string GetBufferRef_(DataType t, const BufferNode *buffer,
PrimExpr index) override;
/*!
* \brief Print expr representing the thread tag
* \param iv The thread index to be bound.
*/
virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*)
virtual void PrintStorageSync_(const CallNode *op);
std::string
CanonicalizeFastmathFunctionName_(const std::string &func_name) const;
private:
// The name of the mbarrier array in shared memory
const std::string mbarrier_name_ = "mbarrier";
std::unordered_map<const VarNode *, IntImm> unroll_factor_;
std::vector<std::string> eviction_policy_names_ = {
"EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};
// Fastmath configuration (read from PassContext)
bool enable_fastmath_ = false;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
/*!
* \file codegen_py.h
* \brief Common utilities to generate simple Python code.
*/
#ifndef TVM_TL_TARGET_CODEGEN_PY_H_
#define TVM_TL_TARGET_CODEGEN_PY_H_
#include <tvm/ir/op.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include <unordered_map>
// from tvm/src/
#include "target/source/codegen_source_base.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace codegen {
using namespace tir;
/*!
* \brief A base class to generate simple Python code.
*/
class CodeGenTileLangPY
: public ExprFunctor<void(const PrimExpr &, std::ostream &)>,
public StmtFunctor<void(const Stmt &)>,
public CodeGenSourceBase {
public:
/*!
* \brief Add the function definition to the generated module.
* \param gvar The GlobalVar representing the function.
* \param func The function to be compiled.
*/
virtual void AddFunction(const GlobalVar &gvar, const PrimFunc &func);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
*/
virtual std::string Finish();
protected:
/*!
* \brief Get the name of a declared function
* \param gvar The GlobalVar of the function
* \returns The string name of the function
*/
ffi::String GetFunctionName_(const GlobalVar &gvar);
/*!
* \brief Reserve the function name in the generated module.
*
* \param gvar The GlobalVar representing the function.
* \param func The function to be compiled.
*/
virtual void RegisterFunction_(const GlobalVar &gvar, const PrimFunc &func);
/*!
* \brief Initialize codegen state for generating f.
* \param f The function to be compiled.
*/
virtual void InitFuncState_(const PrimFunc &f);
/*! \brief Print the function signature before ":"
* \param function_name The name of the function
* \param func The function whose signature should be printed
* \param os The output stream
*/
virtual void PrintFunctionSignature_(const ffi::String &function_name,
const PrimFunc &func,
std::ostream &os); // NOLINT(*)
/*!
* \brief Print the function decorator
* \param os The output stream
*/
virtual void PrintFuncDecorator_(std::ostream &os) {} // NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
*/
virtual void PreFunctionBody_(const PrimFunc &f) {}
protected:
/*! \brief reserves common Python keywords */
void ReserveKeywordsAsUnique_();
void PrintSSAAssign(const std::string &target, const std::string &src,
DataType t) override;
protected:
/*!
* \brief Print the representation of the given type.
* \param type The type representation.
* \param os The output stream
*/
void PrintType(DataType type, std::ostream &os) override; // NOLINT(*)
/*!
* \brief Print the Stmt n to CodeGenTileLangPY->stream
* \param n The statement to be printed.
*/
void PrintStmt_(const Stmt &n) { VisitStmt(n); }
/*!
* \brief Print the expression n into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr_(const PrimExpr &n, std::ostream &os) { // NOLINT(*)
VisitExpr(n, os);
}
/*!
* \brief Same as PrintExpr_, but simply returns result string
* \param n The expression to be printed.
*/
std::string PrintExpr_(const PrimExpr &n) {
std::ostringstream os;
PrintExpr_(n, os);
return os.str();
}
// expression
void VisitExpr_(const VarNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const AddNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const SubNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MulNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const ModNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const EQNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const NENode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const LTNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const LENode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const GTNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const GENode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const AndNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const OrNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const NotNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const RampNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) override; // NOLINT(*)
// statement
void VisitStmt_(const BufferStoreNode *op) override;
void VisitStmt_(const DeclBufferNode *op) override;
void VisitStmt_(const LetStmtNode *op) override;
void VisitStmt_(const AllocateNode *op) override;
void VisitStmt_(const AttrStmtNode *op) override;
void VisitStmt_(const ForNode *op) override;
void VisitStmt_(const WhileNode *op) override;
void VisitStmt_(const IfThenElseNode *op) override;
void VisitStmt_(const SeqStmtNode *op) override;
void VisitStmt_(const EvaluateNode *op) override;
void VisitStmt_(const AssertStmtNode *op) override;
protected:
// Get a string of type casting
virtual std::string CastFromTo_(const std::string &value, DataType from,
DataType target);
virtual void PrintBinaryExpr_(const std::string &opstr, DataType dtype,
PrimExpr lhs, PrimExpr rhs,
std::ostream &os); // NOLINT(*)
virtual void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr,
std::ostream &os); // NOLINT(*)
/*!
* \brief Print external function call.
* \param ret_type The return type.
* \param global_symbol The symbol of the target function.
* \param args The arguments to the function.
* \param skip_first_arg Whether to skip the first argument.
* \param os The output stream.
*/
virtual void PrintCallExtern_(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os); // NOLINT(*)
// Print reference to a buffer as type t in index.
virtual std::string GetBufferRef_(DataType t, const BufferNode *buffer,
PrimExpr index);
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void RegisterHandleType_(const VarNode *buf_var, DataType t);
/*!
* \brief Check whether the buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool HandleTypeMatch_(const VarNode *buf_var, DataType t) const;
protected:
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode *, std::string> alloc_storage_scope_;
/*! \brief Record of ops that have pre-defined global symbol. */
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ =
Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
// cache commonly used ops
const Op &builtin_call_extern_ = builtin::call_extern();
const Op &builtin_call_pure_extern_ = builtin::call_pure_extern();
private:
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode *, DataType> handle_data_type_;
/* \brief Map of GlobalVar to their symbol.
*
* For externally-exposed functions, this is given by the
* tvm::attr::kGlobalSymbol attribute of the PrimFunc. For internal
* functions, this is the name of the function's GlobalVar, possibly
* altered to prevent duplicate names.
*/
std::unordered_map<GlobalVar, ffi::String> internal_functions_;
/* \brief Name supply to generate unique function names */
NameSupply func_name_supply_;
/*!
* \brief Escape a string to be a valid Python double-quoted string literal.
* \param s The input string to escape.
* \param os The output stream to write the escaped string to.
*/
void EscapeStringLiteral_(const std::string &s, std::ostream &os);
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_PY_H_
/*!
* \file target/codegen_utils.cc
* \brief Shared utility functions for code generation
*/
#include "codegen_utils.h"
namespace tvm {
namespace codegen {
bool CheckOutermostParenthesesMatch(const std::string &s) {
if (!s.empty() && s.front() == '(' && s.back() == ')') {
size_t len = s.size();
int n_unmatched = 0;
for (size_t i = 0; i < len; ++i) {
if (s[i] == '(') {
n_unmatched++;
} else if (s[i] == ')') {
n_unmatched--;
}
if (n_unmatched < 0) {
return false;
}
if (n_unmatched == 0) {
return i == len - 1;
}
}
}
return false;
}
std::string RemoveOutermostParentheses(const std::string &s) {
if (CheckOutermostParenthesesMatch(s)) {
return s.substr(1, s.size() - 2);
} else {
return s;
}
}
} // namespace codegen
} // namespace tvm
/*!
* \file target/codegen_utils.h
* \brief Shared utility functions for code generation
*/
#ifndef TVM_TARGET_CODEGEN_UTILS_H_
#define TVM_TARGET_CODEGEN_UTILS_H_
#include <string>
namespace tvm {
namespace codegen {
/*!
* \brief Check if the outermost parentheses match
* \param s The input string
* \return true if the first character is '(' and the last character is ')'
* and they form a matching pair
*/
bool CheckOutermostParenthesesMatch(const std::string &s);
/*!
* \brief Remove outermost parentheses if they match
* \param s The input string
* \return The string with outermost parentheses removed if they match,
* otherwise return the original string
*/
std::string RemoveOutermostParentheses(const std::string &s);
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_UTILS_H_
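To make the documented behavior concrete, here is a small Python sketch of the same check (illustration only; the shipped implementation is the C++ above): the outermost parentheses are stripped only when the very first '(' is the one closed by the final ')'.

def remove_outermost_parentheses(s: str) -> str:
    # Mirrors CheckOutermostParenthesesMatch / RemoveOutermostParentheses above.
    if not (s.startswith("(") and s.endswith(")")):
        return s
    depth = 0
    for i, ch in enumerate(s):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        if depth == 0:
            # Strip only if the first '(' closes at the very last character.
            return s[1:-1] if i == len(s) - 1 else s
    return s

assert remove_outermost_parentheses("(a + b)") == "a + b"
assert remove_outermost_parentheses("(a) + (b)") == "(a) + (b)"  # not one matching pair
assert remove_outermost_parentheses("a + b") == "a + b"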
#include "codegen_cutedsl.h"
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") {
info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
ffi::Module BuildTileLangCuTeDSLWithoutCompile(IRModule mod, Target target) {
CodeGenTileLangCuTeDSL cg;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCuTeDSL: Can only take PrimFunc";
auto gvar = Downcast<GlobalVar>(kv.first);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(gvar, f);
}
std::string code = cg.Finish();
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cutedsl_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cutedsl_without_compile",
BuildTileLangCuTeDSLWithoutCompile);
}
} // namespace codegen
} // namespace tvm
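Note that BuildTileLangCuTeDSLWithoutCompile looks up an optional global function named tilelang_callback_cutedsl_postproc and, when one is registered, routes the generated source through it before packaging it into the runtime module. The new JIT test registers one in exactly this way; a minimal sketch:

from tilelang import tvm as tvm

@tvm.register_global_func("tilelang_callback_cutedsl_postproc", override=True)
def tilelang_callback_cutedsl_postproc(code, target):
    # Receives the generated CuTeDSL source and the build target; whatever is
    # returned replaces the source packaged into the module.
    return "# post-processed by callback\n" + code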
......@@ -173,4 +173,4 @@ template <class T, unsigned I = 0>
inline constexpr size_t extent_v = extent<T, I>::value;
} // namespace std
#endif
\ No newline at end of file
#endif // __CUDACC_RTC__
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
stramp = "&*(XS)"
@tvm.register_global_func("tilelang_callback_cutedsl_postproc", override=True)
def tilelang_callback_cutedsl_postproc(code, _):
code = f"# {stramp}\n" + code
return code
matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl")
kernel_source = matmul_kernel.get_kernel_source()
assert stramp in kernel_source, f"Expected {stramp} in the kernel source"
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl")
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
A = torch.randn(M, K, dtype=in_dtype).cuda()
B = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype)
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def run_cutedsl_kernel_do_bench(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, target="cutedsl")
profiler = matmul_kernel.get_profiler()
cutedsl_latency = profiler.do_bench(func=matmul_kernel)
print(f"CuTeDSL Latency: {cutedsl_latency} ms")
assert cutedsl_latency is not None
tvm_latency = profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")
assert tvm_latency is not None
def test_cutedsl_kernel_do_bench():
run_cutedsl_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cutedsl_kernel_multi_stream(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, target="cutedsl")
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
num_streams = 4
for _ in range(num_streams):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
matmul_kernel(tensor_a, tensor_b, tensor_c)
def test_cutedsl_kernel_multi_stream():
run_cutedsl_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cutedsl_dynamic_shape(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, target="cutedsl")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
N = 1024
if isinstance(K, T.Var):
K = 768
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cutedsl_dynamic_shape():
run_cutedsl_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cutedsl_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cutedsl_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2
)
def check_hopper():
if not torch.cuda.is_available():
return False
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -29,6 +29,11 @@ KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"
PARAMS_PATH = "params.pkl"
# CuTeDSL C++ launcher specific
LAUNCHER_LIB_PATH = "launcher_lib.so"
LAUNCHER_CPP_PATH = "launcher.cpp"
CUTEDSL_CUBIN_PATH = "kernel.cubin"
class KernelCache:
"""
......@@ -43,7 +48,7 @@ class KernelCache:
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi"
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi"
def __new__(cls):
"""
......@@ -72,7 +77,7 @@ class KernelCache:
self,
func: Callable,
out_idx: list[int],
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
args=None,
target: str | Target = "auto",
target_host: str | Target = None,
......@@ -85,7 +90,7 @@ class KernelCache:
Args:
func (Callable): The function to be compiled.
out_idx (List[int]): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi".
args: Arguments passed to the function.
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
......@@ -118,7 +123,7 @@ class KernelCache:
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
verbose: bool = False,
pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
......@@ -217,7 +222,11 @@ class KernelCache:
)
with self._lock:
if env.is_cache_enabled():
cache_path = self._get_cache_path(key)
self._save_kernel_to_disk(key, kernel, func, verbose)
# Set cache path on adapter so it can save cubin after first execution
if hasattr(kernel, "adapter") and execution_backend == "cutedsl":
kernel.adapter._cache_path = cache_path
# Store in memory cache after compilation
self._memory_cache[key] = kernel
......@@ -287,59 +296,83 @@ class KernelCache:
# Save kernel source code
try:
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
if verbose:
self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None:
KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source))
except Exception as e:
self.logger.error(f"Error saving kernel source code to disk: {e}")
if self.execution_backend != "cutedsl":
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
if verbose:
self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None:
KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source))
except Exception:
self.logger.exception("Error saving kernel source code to disk")
# Save wrapped kernel source code
try:
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH if self.execution_backend != "cutedsl" else KERNEL_PY_PATH)
if verbose:
self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
if self.execution_backend == "tvm_ffi":
KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source()))
else:
KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source()))
except Exception as e:
self.logger.error(f"Error saving host kernel source code to disk: {e}")
except Exception:
self.logger.exception("Error saving host kernel source code to disk")
# Save the kernel library
try:
# Save CUBIN or SO file
if self.execution_backend == "nvrtc":
kernel_lib_path = KERNEL_CUBIN_PATH
elif self.execution_backend == "tvm_ffi":
kernel_lib_path = EXECUTABLE_PATH
else:
kernel_lib_path = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
if self.execution_backend == "cutedsl":
# For CuTeDSL, kernel_lib_path is the Python module
kernel_lib_path = os.path.join(cache_path, KERNEL_PY_PATH)
# Save C++ launcher library if it exists
lib_gen = getattr(kernel.adapter, "lib_generator", None)
if lib_gen and hasattr(lib_gen, "launcher_libpath") and lib_gen.launcher_libpath:
launcher_lib_path = os.path.join(cache_path, LAUNCHER_LIB_PATH)
src_launcher_path = lib_gen.launcher_libpath
if verbose:
self.logger.debug(f"Saving C++ launcher library to cache: {src_launcher_path}")
KernelCache._safe_write_file(
launcher_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_launcher_path))
)
# Optionally save launcher C++ source for debugging
if hasattr(kernel.adapter, "launcher_cpp_code") and kernel.adapter.launcher_cpp_code:
launcher_cpp_path = os.path.join(cache_path, LAUNCHER_CPP_PATH)
if verbose:
self.logger.debug(f"Saving C++ launcher source to: {launcher_cpp_path}")
KernelCache._safe_write_file(launcher_cpp_path, "w", lambda file: file.write(kernel.adapter.launcher_cpp_code))
# Save an extra Python file for NVRTC
if self.execution_backend == "nvrtc":
src_lib_path = kernel.adapter.libpath
kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
src_lib_path = src_lib_path.replace(".cubin", ".py")
if verbose:
self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
elif self.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable
if verbose:
self.logger.debug(f"Saving kernel executable to file: {executable}")
KernelCache._safe_write_executable(executable, kernel_lib_path)
else:
src_lib_path = kernel.adapter.libpath
if verbose:
self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
except Exception as e:
self.logger.error(f"Error saving kernel library to disk: {e}")
if self.execution_backend == "nvrtc":
kernel_lib_path = KERNEL_CUBIN_PATH
elif self.execution_backend == "tvm_ffi":
kernel_lib_path = EXECUTABLE_PATH
else:
kernel_lib_path = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
# Save an extra Python file for NVRTC
if self.execution_backend == "nvrtc":
src_lib_path = kernel.adapter.libpath
kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
src_lib_path = src_lib_path.replace(".cubin", ".py")
if verbose:
self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
if self.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable
if verbose:
self.logger.debug(f"Saving kernel executable to file: {executable}")
KernelCache._safe_write_executable(executable, kernel_lib_path)
else:
src_lib_path = kernel.adapter.libpath
if verbose:
self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
except Exception:
self.logger.exception("Error saving kernel library to disk")
# Save kernel parameters
try:
......@@ -347,19 +380,19 @@ class KernelCache:
if verbose:
self.logger.debug(f"Saving kernel parameters to disk: {params_path}")
KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file))
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")
except Exception:
self.logger.exception("Error saving kernel parameters to disk")
def _load_kernel_from_disk(
self,
key: str,
target: str | Target = "auto",
target_host: str | Target = None,
out_idx: list[int] = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
pass_configs: dict = None,
target_host: str | Target | None = None,
out_idx: list[int] | None = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
func: Callable = None,
func: Callable | None = None,
verbose: bool = False,
) -> JITKernel | None:
"""
......@@ -370,7 +403,7 @@ class KernelCache:
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
out_idx (List[int], optional): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi".
pass_configs (dict, optional): Configuration for compiler passes.
func (Callable, optional): The original function.
verbose (bool): Enable verbose log messages.
......@@ -385,11 +418,21 @@ class KernelCache:
kernel_lib_path = KERNEL_CUBIN_PATH
elif self.execution_backend == "tvm_ffi":
kernel_lib_path = EXECUTABLE_PATH
elif self.execution_backend == "cutedsl":
kernel_lib_path = KERNEL_PY_PATH
else:
kernel_lib_path = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
params_path = os.path.join(cache_path, PARAMS_PATH)
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
# Check required files exist
required_files = [kernel_lib_path, params_path]
# For CuTeDSL, also check launcher library
if self.execution_backend == "cutedsl":
required_files.append(os.path.join(cache_path, LAUNCHER_LIB_PATH))
if not all([os.path.exists(file) for file in required_files]):
return None
device_kernel_source: str | None = None
......@@ -397,20 +440,25 @@ class KernelCache:
kernel_params: list[KernelParam] | None = None
# Load the kernel source file (optional)
try:
if verbose:
self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
with open(device_kernel_path) as f:
device_kernel_source = f.read()
except Exception as e:
self.logger.error(f"Error loading kernel source code from disk: {e}")
try:
if verbose:
self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
with open(host_kernel_path) as f:
host_kernel_source = f.read()
except Exception as e:
self.logger.error(f"Error loading host kernel source code from disk: {e}")
if self.execution_backend != "cutedsl":
try:
if verbose:
self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
with open(device_kernel_path) as f:
device_kernel_source = f.read()
except Exception:
self.logger.exception("Error loading kernel source code from disk")
try:
if verbose:
self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
with open(host_kernel_path) as f:
host_kernel_source = f.read()
except Exception:
self.logger.exception("Error loading host kernel source code from disk")
else:
# For CuTeDSL, set empty strings since sources aren't loaded from cache
device_kernel_source = ""
host_kernel_source = ""
# Load kernel parameters
try:
......@@ -418,10 +466,10 @@ class KernelCache:
self.logger.debug(f"Loading kernel parameters from file: {params_path}")
with open(params_path, "rb") as f:
kernel_params = cloudpickle.load(f)
except Exception as e:
self.logger.error(f"Error loading kernel parameters from disk: {e}")
except Exception:
self.logger.exception("Error loading kernel parameters from disk")
if host_kernel_source and device_kernel_source and kernel_params:
if ((host_kernel_source and device_kernel_source) or self.execution_backend == "cutedsl") and kernel_params:
return JITKernel.from_database(
func=func,
host_kernel_source=host_kernel_source,
......@@ -453,5 +501,5 @@ class KernelCache:
# Re-create the cache directory
KernelCache._create_dirs()
except Exception as e:
self.logger.error(f"Error clearing disk cache: {e}")
except Exception:
self.logger.exception("Error clearing disk cache")
import cutlass
import cutlass.cute as cute
from cutlass._mlir.dialects import nvvm
from cutlass.cutlass_dsl import T
# re-export cutlass.cute.arch functions first
from cutlass.cute.arch import sync_threads # noqa: F401
from cutlass.cute.arch import alloc_smem, get_dyn_smem # noqa: F401
from cutlass.cute.arch import warpgroup_reg_alloc, warpgroup_reg_dealloc # noqa: F401
from cutlass.cute import make_tensor, make_rmem_tensor, recast_ptr # noqa: F401
from cutlass.cute.typing import Numeric
from cutlass.base_dsl.typing import as_numeric, Int32, Uint16, Uint32 # noqa: F401
from cutlass._mlir.dialects import llvm, arith # noqa: F401
from cutlass._mlir import ir as mlir_ir
from cutlass.cutlass_dsl import dsl_user_op
# Import our custom implementations (will override if names conflict)
from .mbar import *
from .cpasync import *
from .gemm_V1 import *
from .reduce import *
from .ldsm import *
from .math import *
from .threadblock_swizzle import *
# Forward nvvm enums
from cutlass._mlir.dialects.nvvm import (
MemOrderKind,
MemScopeKind,
AtomicOpKind,
)
BYTES_PER_TENSORMAP = 128
BYTES_PER_POINTER = 8
def make_filled_tensor(shape, value):
t = cute.make_rmem_tensor(shape, type(value))
t.fill(value)
return t
def make_tensor_at_offset(ptr: cute.Pointer, offset, shape, div_by=1):
if div_by != 1:
offset = cute.assume(cutlass.as_numeric(offset), divby=div_by)
return cute.make_tensor(ptr + offset, shape)
def shuffle_elect(thread_extent):
# thread_extent is the number of threads of a warpgroup
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
if thread_extent == 0:
return warp_idx == 0
else:
return (warp_idx % (thread_extent // 32)) == 0
def sync_thread_partial(barrier_id=None, thread_count=None):
bar_sync_ptx(barrier_id, thread_count)
# Packing functions
def pack_half2(x, y):
"""
Pack two half-precision (fp16) values into a single 32-bit value.
Corresponds to CUDA's __pack_half2 intrinsic.
This packs two fp16 values into a single int32 by treating the fp16 bits
as raw data and concatenating them.
"""
@dsl_user_op
def pack_half2_impl(x_val, y_val, *, loc=None, ip=None):
# Cast fp16 to uint16 (bitcast)
x_ir = x_val.ir_value(loc=loc, ip=ip) if hasattr(x_val, "ir_value") else x_val
y_ir = y_val.ir_value(loc=loc, ip=ip) if hasattr(y_val, "ir_value") else y_val
# Bitcast fp16 to i16
i16_type = mlir_ir.IntegerType.get_signless(16)
x_i16 = llvm.bitcast(i16_type, x_ir, loc=loc, ip=ip)
y_i16 = llvm.bitcast(i16_type, y_ir, loc=loc, ip=ip)
packed_xy = llvm.inline_asm(
Int32.mlir_type,
[x_i16, y_i16],
"mov.b32 $0, {$1, $2};",
"=r,h,h",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
return Int32(packed_xy)
return pack_half2_impl(x, y)
def AtomicAdd(ptr: cute.Pointer, value: Numeric, *, loc=None, ip=None):
if ptr.dtype == cutlass.Float32:
ret = nvvm.atomicrmw(
T.f32(),
AtomicOpKind.FADD,
ptr.llvm_ptr,
ptr.dtype(value).ir_value(loc=loc, ip=ip),
mem_order=MemOrderKind.RELAXED,
syncscope=MemScopeKind.GPU,
loc=loc,
ip=ip,
)
elif ptr.dtype == cutlass.Int32:
ret = nvvm.atomicrmw(
T.i32(),
AtomicOpKind.ADD,
ptr.llvm_ptr,
ptr.dtype(value).ir_value(loc=loc, ip=ip),
mem_order=MemOrderKind.RELAXED,
syncscope=MemScopeKind.GPU,
loc=loc,
ip=ip,
)
else:
raise ValueError(f"Unsupported dtype: {ptr.dtype}")
return ptr.dtype(ret)
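The pack_half2 docstring above describes the result as the raw fp16 bit patterns concatenated into one 32-bit value; as I read the mov.b32 operand order, the first operand lands in the low 16 bits. A host-side reference of that layout (illustration only, using torch bit views; the device path uses the inline PTX above):

import torch

def pack_half2_ref(x: float, y: float) -> int:
    # Bit-cast each fp16 to its raw 16-bit pattern, then concatenate:
    # x in the low 16 bits, y in the high 16 bits.
    xb = torch.tensor(x, dtype=torch.float16).view(torch.int16).item() & 0xFFFF
    yb = torch.tensor(y, dtype=torch.float16).view(torch.int16).item() & 0xFFFF
    return (yb << 16) | xb

assert pack_half2_ref(1.0, 0.0) == 0x3C00  # fp16(1.0) has bit pattern 0x3C00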
from __future__ import annotations
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op # noqa: F401
from cutlass._mlir.dialects import nvvm, cute_nvgpu # noqa: F401
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
import cutlass.cute as cute
from cutlass.cute.typing import Int, Boolean, Int32, Int16, Uint64, Union # noqa: F401
from cutlass.impl_utils import check_value_in
from cutlass.cute.arch import cp_async_commit_group as cp_async_commit # noqa: F401
from cutlass.cute.arch import cp_async_wait_group as cp_async_wait # noqa: F401
BYTES_PER_TENSORMAP = 128
BYTES_PER_POINTER = 8
def cp_async_gs(size, dst, dst_offset, src, src_offset):
assert size in [16, 8, 4]
# Use CG (cache-global) to bypass L1 when loading a contiguous 128B line.
mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA
if isinstance(src, cute.Tensor):
src_ptr = src.iterator
elif isinstance(src, cute.Pointer):
src_ptr = src
else:
raise ValueError(f"Invalid source type: {type(src)}")
if isinstance(dst, cute.Tensor):
dst_ptr = dst.iterator
elif isinstance(dst, cute.Pointer):
dst_ptr = dst
else:
raise ValueError(f"Invalid destination type: {type(dst)}")
cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode)
@cute.jit
def cp_async_gs_conditional(size, dst, dst_offset, src, src_offset, cond):
if cond:
cp_async_gs(size, dst, dst_offset, src, src_offset)
@dsl_user_op
def extract_tensormap_ptr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
"""
Extract the tensormap pointer from a TMA Copy Atom.
:param tma_atom: The TMA Copy Atom
:type tma_atom: CopyAtom
"""
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
ptr_type = _cute_ir.PtrType.get(Uint64.mlir_type, _cute_ir.AddressSpace.generic, 64)
tensormap_ptr = _cute_nvgpu_ir.get_tma_desc_addr(ptr_type, exec_value, loc=loc, ip=ip)
return tensormap_ptr
@dsl_user_op
def tma_load(tma_desc, mbar: cute.Pointer, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None:
"""
Load data from global memory to shared memory using TMA (Tensor Memory Access).
:param tma_desc: TMA descriptor for the tensor
:type tma_desc: CopyAtom or tensormap_ptr or Tensor of tensormap_ptr
:param mbar: Mbarrier pointer in shared memory
:type mbar: Pointer
:param smem_ptr: Destination pointer in shared memory
:type smem_ptr: Pointer
:param crd: Coordinates tuple for the tensor access
:type crd: tuple[Int, ...]
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
if not isinstance(crd, tuple) and isinstance(tma_desc, cute.Pointer):
# Legacy signature: tma_load(smem_ptr, gmem_ptr, mbar, size)
_smem_ptr = tma_desc
_gmem_ptr = mbar
_mbar = smem_ptr
nvvm.cp_async_bulk_shared_cluster_global(
dst_mem=_smem_ptr.llvm_ptr,
src_mem=_gmem_ptr.llvm_ptr,
mbar=_mbar.llvm_ptr,
size=Int32(crd).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
else:
if isinstance(tma_desc, cute.CopyAtom):
tma_desc_ptr = extract_tensormap_ptr(tma_desc)
elif isinstance(tma_desc, cute.Tensor):
tma_desc_ptr = tma_desc.iterator
else:
tma_desc_ptr = tma_desc
nvvm.cp_async_bulk_tensor_shared_cluster_global(
dst_mem=smem_ptr.llvm_ptr,
tma_descriptor=tma_desc_ptr.llvm_ptr,
coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd],
mbar=mbar.llvm_ptr,
im2col_offsets=[],
load_mode=nvvm.CpAsyncBulkTensorLoadMode.TILE,
group=nvvm.Tcgen05GroupKind.CTA_1,
use_intrinsic=False, # set to True would lead to compile error
loc=loc,
ip=ip,
)
@dsl_user_op
def tma_store(tma_desc, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None:
"""
Store data from shared memory to global memory using TMA (Tensor Memory Access).
:param tma_desc: TMA descriptor for the tensor
:type tma_desc: TMA descriptor
:param smem_ptr: Source pointer in shared memory
:type smem_ptr: Pointer
:param crd: Coordinates tuple for the tensor access
:type crd: tuple[Int, ...]
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
if not isinstance(crd, tuple):
if arch not in ("sm_90", "sm_90a"):
raise NotImplementedError("tma_store(size) path is only implemented for sm_90/sm_90a")
gmem_ptr = tma_desc.align(smem_ptr.alignment)
_cute_nvgpu_ir.arch_copy_SM90_bulk_copy_s2g(
dsmem_data_addr=smem_ptr.value,
gmem_data_addr=gmem_ptr.value,
size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), crd),
loc=loc,
ip=ip,
)
else:
if isinstance(tma_desc, cute.CopyAtom):
tma_desc_ptr = extract_tensormap_ptr(tma_desc)
elif isinstance(tma_desc, cute.Tensor):
tma_desc_ptr = tma_desc.iterator
else:
tma_desc_ptr = tma_desc
nvvm.cp_async_bulk_tensor_global_shared_cta(
tma_descriptor=tma_desc_ptr.llvm_ptr,
src_mem=smem_ptr.llvm_ptr,
coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd],
predicate=None,
loc=loc,
ip=ip,
)
@dsl_user_op
def tma_store_arrive(*, loc=None, ip=None) -> None:
"""
Indicate arrival of warp issuing TMA_STORE.
Corresponds to PTX instruction: cp.async.bulk.commit_group;
"""
nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip)
@dsl_user_op
def tma_store_wait(count: int, *, read=None, loc=None, ip=None) -> None:
"""
Wait for TMA_STORE operations to complete.
Corresponds to PTX instruction: cp.async.bulk.wait_group.read <count>;
:param count: The number of outstanding bulk async groups to wait for
:type count: Int
"""
nvvm.cp_async_bulk_wait_group(group=count, read=read, loc=loc, ip=ip)
@dsl_user_op
def cp_async_shared_global(
dst: cute.Pointer, src: cute.Pointer, cp_size: Int, modifier: nvvm.LoadCacheModifierKind, *, src_size: Int = None, loc=None, ip=None
) -> None:
"""
Asynchronously copy data from global memory to shared memory.
:param dst: Destination pointer in shared memory
:type dst: Pointer
:param src: Source pointer in global memory
:type src: Pointer
:param cp_size: Size of the copy in bytes
:type cp_size: Int
:param modifier: Cache modifier
:type modifier: LoadCacheModifierKind
:param src_size: Optional number of source bytes to load (defaults to cp_size)
:type src_size: Int
"""
size = src_size if src_size else cp_size
nvvm.cp_async_shared_global(
dst=dst.llvm_ptr,
src=src.llvm_ptr,
size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), size),
modifier=modifier,
cp_size=Int32(cp_size).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
@dsl_user_op
def prefetch_tma_descriptor(tma_desc, *, loc=None, ip=None) -> None:
"""
Prefetch a TMA descriptor.
Corresponds to PTX instruction: prefetch.tensormap;
"""
if isinstance(tma_desc, cute.CopyAtom):
tma_desc_ptr = extract_tensormap_ptr(tma_desc)
elif isinstance(tma_desc, cute.Tensor):
tma_desc_ptr = tma_desc.iterator
else:
tma_desc_ptr = tma_desc
nvvm.prefetch_tensormap(tma_desc_ptr.llvm_ptr, loc=loc, ip=ip)
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils # noqa: F401
import math
import cutlass.utils.hopper_helpers as hopper_utils
from cutlass.utils import LayoutEnum
from cutlass.cute.nvgpu.warpgroup import OperandMajorMode, OperandSource, make_smem_layout_atom
def make_aligned_tensor(ptr: cute.Pointer, layout: cute.Layout, align_bytes: int, swizzle=False):
ptr = ptr.align(align_bytes)
if swizzle and isinstance(layout, cute.ComposedLayout):
ptr = cute.recast_ptr(ptr=ptr, swizzle_=layout.inner, dtype=ptr.dtype)
return cute.make_tensor(ptr, layout.outer)
return cute.make_tensor(ptr, layout)
def gemm_ss(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with both A and B from shared memory"""
if use_wgmma:
gemm = Gemm_SM90(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum)
else:
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm(A_ptr, B_ptr, C_ptr)
def gemm_rs(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with A from register/fragment and B from shared memory"""
if use_wgmma:
gemm = Gemm_SM90(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm.body_rs(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum)
else:
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm.body_rs(A_ptr, B_ptr, C_ptr)
def gemm_sr(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with A from shared memory and B from register/fragment"""
# wgmma doesn't support gemm_sr, only use SM80
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm.body_sr(A_ptr, B_ptr, C_ptr)
def gemm_rr(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with both A and B from register/fragment"""
# Both operands in register, no copy needed
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
# For gemm_rr, directly call _body_impl with copy_A=False, copy_B=False
gemm._body_impl(A_ptr, B_ptr, C_ptr, copy_A=False, copy_B=False)
class Gemm_SM80:
_instances = {} # cache instances for the same arguments
def __new__(cls, *args):
key = args
if key not in cls._instances:
cls._instances[key] = super().__new__(cls)
return cls._instances[key]
# in Tilelang, trans_A == 0 or trans_B == 1 means K major
# in Cute, trans == 0 means K major
def __init__(
self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type
):
if not hasattr(self, "initialized"):
self.cta_tiler = (M, N, K)
self.mma_inst_shape = (16, 8, 16)
self.trans_A = trans_A != 0 # same with Tilelang
self.trans_B = trans_B == 0 # inverse with Tilelang
A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR
B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR
self.A_layout = self._make_smem_layout_AB(A_type, A_major_mode, 128, (M, K))
self.B_layout = self._make_smem_layout_AB(B_type, B_major_mode, 128, (N, K))
self.ab_dtype = A_type
self.acc_dtype = C_type
self.tiled_mma = self._make_tiled_mma(warp_m, warp_n)
self.clear_accum = clear_accum
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
is_row_major = major_mode == LayoutEnum.ROW_MAJOR
major_mode_size = smem_tiler[1] if is_row_major else smem_tiler[0]
major_mode_size = 64 if major_mode_size >= 64 else major_mode_size
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
swizzle_bits = min(swizzle_bits, 3)
layout_atom_outer = (
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
if is_row_major
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
)
layout_atom = cute.make_composed_layout(
cute.make_swizzle(swizzle_bits, 3, 3),
0,
layout_atom_outer,
)
layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1) if is_row_major else (1, 0))
return layout
def _make_tiled_mma(self, warp_m, warp_n):
atom_layout_mnk = (warp_m, warp_n, 1)
op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_inst_shape)
permutation_mnk = (
atom_layout_mnk[0] * self.mma_inst_shape[0],
atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
atom_layout_mnk[2] * self.mma_inst_shape[2],
)
tiled_mma = cute.make_tiled_mma(op, atom_layout_mnk, permutation_mnk)
return tiled_mma
@cute.jit
def __call__(
self,
sA_ptr: cute.Pointer,
sB_ptr: cute.Pointer,
rC_ptr: cute.Pointer,
):
"""GEMM body: both A and B from shared memory"""
self._body_impl(sA_ptr, sB_ptr, rC_ptr, copy_A=True, copy_B=True)
@cute.jit
def body_rs(
self,
rA_ptr: cute.Pointer, # A already in register
sB_ptr: cute.Pointer, # B from shared memory
rC_ptr: cute.Pointer,
):
"""GEMM body_rs: A from register, B from shared memory"""
self._body_impl(rA_ptr, sB_ptr, rC_ptr, copy_A=False, copy_B=True)
@cute.jit
def body_sr(
self,
sA_ptr: cute.Pointer, # A from shared memory
rB_ptr: cute.Pointer, # B already in register
rC_ptr: cute.Pointer,
):
"""GEMM body_sr: A from shared memory, B from register"""
self._body_impl(sA_ptr, rB_ptr, rC_ptr, copy_A=True, copy_B=False)
@cute.jit
def _body_impl(
self,
A_ptr: cute.Pointer,
B_ptr: cute.Pointer,
rC_ptr: cute.Pointer,
copy_A: cutlass.Constexpr = True,
copy_B: cutlass.Constexpr = True,
):
"""Internal implementation with configurable copy operations"""
tidx, _, _ = cute.arch.thread_idx()
thr_mma = self.tiled_mma.get_slice(tidx)
tCrA = None
tCrB = None
tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1])))
# Create copy operations only for operands that need copying
if cutlass.const_expr(copy_A):
sA = make_aligned_tensor(A_ptr, self.A_layout, 16)
tCsA = thr_mma.partition_A(sA)
tCrA = self.tiled_mma.make_fragment_A(tCsA)
atom_copy_s2r_A = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_A, 4),
sA.element_type,
)
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=self.tiled_mma.tv_layout_A_tiled,
tiler_mn=(self.tiled_mma.get_tile_size(0), self.tiled_mma.get_tile_size(2)),
)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
else:
# A already in register
tCrA = cute.make_tensor(A_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2])))
if cutlass.const_expr(copy_B):
sB = make_aligned_tensor(B_ptr, self.B_layout, 16)
tCsB = thr_mma.partition_B(sB)
tCrB = self.tiled_mma.make_fragment_B(tCsB)
atom_copy_s2r_B = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_B, 4),
sB.element_type,
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=self.tiled_mma.tv_layout_B_tiled,
tiler_mn=(self.tiled_mma.get_tile_size(1), self.tiled_mma.get_tile_size(2)),
)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
else:
# B already in register
tCrB = cute.make_tensor(B_ptr, self.tiled_mma.partition_shape_B((self.cta_tiler[1], self.cta_tiler[2])))
if self.clear_accum:
tCrC.fill(0)
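# Per k-block: refresh the register fragment of whichever operand lives in
# shared memory (via the LDMATRIX tiled copies built above), then issue the MMA.
# Operands already in registers are consumed directly in the layout given by
# partition_shape_A / partition_shape_B.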
for k in cutlass.range(cute.size(tCrA, mode=[2])):
if cutlass.const_expr(copy_A):
cute.copy(tiled_copy_s2r_A, tCsA_copy_view[None, None, k], tCrA_copy_view[None, None, k])
if cutlass.const_expr(copy_B):
cute.copy(tiled_copy_s2r_B, tCsB_copy_view[None, None, k], tCrB_copy_view[None, None, k])
cute.gemm(self.tiled_mma, tCrC, tCrA[None, None, k], tCrB[None, None, k], tCrC)
class Gemm_SM90:
_instances = {} # cache instances for the same arguments
def __new__(cls, *args):
key = args
if key not in cls._instances:
cls._instances[key] = super().__new__(cls)
return cls._instances[key]
# In TileLang, trans_A == 0 or trans_B == 1 means the operand is K-major;
# in CuTe, trans == 0 means K-major, hence the flag conversions in __init__ below.
def __init__(
self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type
):
if not hasattr(self, "initialized"):
self.cta_tiler = (M, N, K)
self.tiler_mn = (M, N)
self.atom_layout_mnk = (warp_m // 4, warp_n, 1) # warp_m counts warps; each WGMMA atom spans a warpgroup of 4 warps
self.trans_A = trans_A != 0 # same convention as TileLang
self.trans_B = trans_B == 0 # inverted relative to TileLang
self.a_leading_mode = OperandMajorMode.MN if self.trans_A else OperandMajorMode.K
self.b_leading_mode = OperandMajorMode.MN if self.trans_B else OperandMajorMode.K
A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR
B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR
self.A_layout = self.make_smem_layout_AB(A_type, A_major_mode, (M, K))
self.B_layout = self.make_smem_layout_AB(B_type, B_major_mode, (N, K))
self.a_dtype = A_type
self.b_dtype = B_type
self.acc_dtype = C_type
self.tiled_mma = None
self.A_source = None
self.clear_accum = clear_accum
self.initialized = True # mark the cached instance so __init__ is not re-run
@staticmethod
def make_tma_atom(
tensor,
smem_layout_staged,
smem_tile,
mcast_dim,
):
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() if mcast_dim == 1 else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
op,
tensor,
smem_layout,
smem_tile,
num_multicast=mcast_dim,
)
return tma_atom
@staticmethod
def get_tma_atom(tensor, tiler_mk, stages=1):
smem_layout_staged = Gemm_SM90.make_smem_layout_AB(tensor.element_type, LayoutEnum.from_tensor(tensor), tiler_mk, stages)
tma_atom = Gemm_SM90.make_tma_atom(tensor, smem_layout_staged, tiler_mk, 1)
return tma_atom
@staticmethod
def make_smem_layout_AB(dtype, major_mode: LayoutEnum, tiler_mk, stages=1):
smem_shape = tiler_mk
# Determine if K is the major mode and get the major mode size
is_k_major = major_mode.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K
major_mode_size = tiler_mk[1] if is_k_major else tiler_mk[0]
# Create SMEM layout atom for A tensor based on major mode and data type
smem_layout_atom = make_smem_layout_atom(
hopper_utils.get_smem_layout_atom(major_mode, dtype, major_mode_size),
dtype,
)
# Tile the SMEM layout atom to the A tensor shape and add staging dimension
smem_layout = cute.tile_to_shape(smem_layout_atom, cute.append(smem_shape, stages), order=(0, 1, 2) if is_k_major else (1, 0, 2))
return smem_layout
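# Note: with stages = 1 (the default used by get_tma_atom above) the appended
# stage mode has extent 1, and make_tma_atom later drops it again via
# cute.slice_(smem_layout_staged, (None, None, 0)).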
def _make_tiled_mma(self, is_rsMode=False):
tiled_mma = hopper_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_leading_mode,
self.b_leading_mode,
self.acc_dtype,
self.atom_layout_mnk,
(64, self.tiler_mn[1] // self.atom_layout_mnk[1]),
OperandSource.SMEM if not is_rsMode else OperandSource.RMEM,
)
return tiled_mma
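# Illustrative shapes (assumed configuration): warp_m = 8 warps, i.e. two
# warpgroups, and warp_n = 1 give atom_layout_mnk = (2, 1, 1); with N = 128 the
# per-atom (M, N) tile handed to make_trivial_tiled_mma is (64, 128).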
@cute.jit
def __call__(
self,
sA_ptr: cute.Pointer,
sB_ptr: cute.Pointer,
rC_ptr: cute.Pointer,
wg_wait: cutlass.Constexpr = 0,
clear_accum: cutlass.Constexpr = False,
):
tidx, _, _ = cute.arch.thread_idx()
self.tiled_mma = self._make_tiled_mma()
thr_mma = self.tiled_mma.get_slice(tidx)
sA_ptr = cute.recast_ptr(sA_ptr, self.A_layout.inner, dtype=sA_ptr.dtype)
sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype)
sA = cute.make_tensor(sA_ptr, self.A_layout.outer)
sB = cute.make_tensor(sB_ptr, self.B_layout.outer)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCrA = self.tiled_mma.make_fragment_A(tCsA)
tCrB = self.tiled_mma.make_fragment_B(tCsB)
tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1])))
cute.nvgpu.warpgroup.fence()
if cutlass.const_expr(clear_accum):
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
else:
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
num_k_blocks = cute.size(tCrA, mode=[2])
for k in cutlass.range(num_k_blocks):
tCrA_1phase = tCrA[None, None, k, 0]
tCrB_1phase = tCrB[None, None, k, 0]
cute.gemm(self.tiled_mma, tCrC, tCrA_1phase, tCrB_1phase, tCrC)
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.nvgpu.warpgroup.commit_group()
if cutlass.const_expr(wg_wait >= 0):
cute.nvgpu.warpgroup.wait_group(wg_wait)
@cute.jit
def body_rs(
self,
rA_ptr: cute.Pointer, # A already in register (Fragment)
sB_ptr: cute.Pointer, # B from shared memory
rC_ptr: cute.Pointer,
wg_wait: cutlass.Constexpr = 0,
clear_accum: cutlass.Constexpr = False,
):
"""
GEMM body_rs for SM90/Hopper: A from register, B from shared memory.
Based on cute::tl_wgmma::GemmTensorOp::body_rs from gemm_sm90.h
"""
tidx, _, _ = cute.arch.thread_idx()
self.tiled_mma = self._make_tiled_mma(is_rsMode=True)
# if self.A_source != OperandSource.RMEM or self.tiled_mma is None:
# self.tiled_mma = self._make_tiled_mma(is_rsMode = True)
# self.A_source = OperandSource.RMEM
# B from shared memory (with swizzle)
sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype)
sB = cute.make_tensor(sB_ptr, self.B_layout.outer)
# Use the existing tiled_mma
thr_mma = self.tiled_mma.get_slice(tidx)
# Partition B from shared memory - standard path
tCsB = thr_mma.partition_B(sB)
tCrB = self.tiled_mma.make_fragment_B(tCsB)
# A already in register
# For body_rs, A is NOT partitioned through thr_mma (it's already partitioned)
# We create the tensor directly with the full shape
# This matches C++: make_tensor(make_rmem_ptr(pA), partition_shape_A(...))
tCrA = cute.make_tensor(rA_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2])))
# C accumulator
tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1])))
# Fence operands (prepare for wgmma)
cute.nvgpu.warpgroup.fence()
# Note: warpgroup_arrive() is called internally by wgmma
# Set accumulation mode
if cutlass.const_expr(clear_accum):
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
else:
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
# GEMM loop
num_k_blocks = cute.size(tCrB, mode=[2])
for k_block in cutlass.range(num_k_blocks):
# Match the indexing pattern from __call__
# If tCrB has 4 dimensions (with pipeline), use [None, None, k, 0]
# Otherwise use [None, None, k]
tCrB_k = tCrB[None, None, k_block, 0] if cute.rank(tCrB) >= 4 else tCrB[None, None, k_block]
tCrA_k = tCrA[None, None, k_block, 0] if cute.rank(tCrA) >= 4 else tCrA[None, None, k_block]
cute.gemm(self.tiled_mma, tCrC, tCrA_k, tCrB_k, tCrC)
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.nvgpu.warpgroup.commit_group()
if cutlass.const_expr(wg_wait >= 0):
cute.nvgpu.warpgroup.wait_group(wg_wait)
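# Usage sketch (illustrative; pointer names are assumptions for what the
# generated @cute.jit kernel passes in). wg_wait maps onto wgmma wait_group:
# 0 waits for all outstanding WGMMAs, while a negative value skips the wait so
# the caller can overlap it with other work.
#
#   gemm = Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum,
#                    stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)
#   gemm(sA_ptr, sB_ptr, rC_ptr, wg_wait=0)           # ss path: A and B in smem
#   gemm.body_rs(rA_ptr, sB_ptr, rC_ptr, wg_wait=-1)  # rs path, defer the wait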
"""
LDMATRIX and STMATRIX operations for the CuTeDSL backend.
Based on tl_templates/cuda/ldsm.h
These functions provide wrappers around PTX ldmatrix/stmatrix instructions
for loading/storing 8x8 matrix fragments between shared memory and registers.
"""
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass._mlir import ir # noqa: F401
from cutlass.cute.typing import Pointer, Int32 # noqa: F401
import cutlass.cute as cute
def _to_ir_value(v, loc=None, ip=None):
"""Convert value to MLIR IR, handling both cutlass types and raw MLIR Values"""
if hasattr(v, "ir_value"):
return v.ir_value(loc=loc, ip=ip)
else:
# Already an MLIR Value
return v
def _ldmatrix(smem_ptr, local_ptr, num, transpose, loc=None, ip=None):
"""Internal helper for ldmatrix operations"""
layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row
assert num in [2, 4]
ret_type = llvm.StructType.get_literal([T.i32()] * num)
out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip)
out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), num)
for i in range(num):
out[i] = cute.Int32(llvm.extractvalue(T.i32(), out_i32, [i], loc=loc, ip=ip))
def _stmatrix(smem_ptr, values, transpose, loc=None, ip=None):
"""Internal helper for stmatrix operations"""
layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row
ir_values = [_to_ir_value(v, loc, ip) for v in values]
nvvm.stmatrix(smem_ptr.llvm_ptr, ir_values, layout=layout, loc=loc, ip=ip)
# ============================================================================
# LDMATRIX operations (load from shared memory to registers)
# ============================================================================
@dsl_user_op
def ptx_ldmatrix_x1(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 1 matrix (8x8) from shared memory"""
# num == 1 is handled inline rather than via _ldmatrix, which only accepts
# num in (2, 4): for a single matrix, nvvm.ldmatrix returns a bare i32, not a struct.
out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.row, loc=loc, ip=ip)
out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
out[0] = cute.Int32(out_i32)
@dsl_user_op
def ptx_ldmatrix_x2(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 2 matrices (8x8 each) from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 2, False, loc, ip)
@dsl_user_op
def ptx_ldmatrix_x4(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 4 matrices (8x8 each) from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 4, False, loc, ip)
@dsl_user_op
def ptx_ldmatrix_x1_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 1 matrix (8x8) with transpose from shared memory"""
out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.col, loc=loc, ip=ip)
out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
out[0] = cute.Int32(out_i32)
@dsl_user_op
def ptx_ldmatrix_x2_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 2 matrices (8x8 each) with transpose from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 2, True, loc, ip)
@dsl_user_op
def ptx_ldmatrix_x4_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 4 matrices (8x8 each) with transpose from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 4, True, loc, ip)
# ============================================================================
# STMATRIX operations (store from registers to shared memory)
# ============================================================================
@dsl_user_op
def ptx_stmatrix_x1(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None:
"""Store 1 matrix (8x8) to shared memory"""
_stmatrix(smem_ptr, [value0], False, loc, ip)
@dsl_user_op
def ptx_stmatrix_x2(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None:
"""Store 2 matrices (8x8 each) to shared memory"""
_stmatrix(smem_ptr, [value0, value1], False, loc, ip)
@dsl_user_op
def ptx_stmatrix_x4(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None:
"""Store 4 matrices (8x8 each) to shared memory"""
_stmatrix(smem_ptr, [value0, value1, value2, value3], False, loc, ip)
@dsl_user_op
def ptx_stmatrix_x1_trans(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None:
"""Store 1 matrix (8x8) with transpose to shared memory"""
_stmatrix(smem_ptr, [value0], True, loc, ip)
@dsl_user_op
def ptx_stmatrix_x2_trans(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None:
"""Store 2 matrices (8x8 each) with transpose to shared memory"""
_stmatrix(smem_ptr, [value0, value1], True, loc, ip)
@dsl_user_op
def ptx_stmatrix_x4_trans(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None:
"""Store 4 matrices (8x8 each) with transpose to shared memory"""
_stmatrix(smem_ptr, [value0, value1, value2, value3], True, loc, ip)
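# Usage sketch (illustrative; the pointer and register names are assumptions).
# x1/x2/x4 move 1, 2, or 4 8x8 16-bit fragments per warp, with each thread
# holding (or supplying) one 32-bit register per fragment:
#
#   ptx_ldmatrix_x4(smem_frag_ptr, reg_frag_ptr)   # smem -> 4 x i32 per thread
#   ptx_stmatrix_x2_trans(smem_frag_ptr, r0, r1)   # 2 x i32 per thread -> smem, transposed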
import cutlass.cute as cute
from cutlass.cute.typing import Union, Numeric
from cutlass.cute.tensor import TensorSSA
from cutlass._mlir.dialects import arith
from cutlass.cute.math import exp, exp2, log, log2, log10, tan, cos, sin, sqrt # noqa: F401
def divf(x: Union[TensorSSA, Numeric], y: Union[TensorSSA, Numeric], fastmath: bool = False) -> Union[TensorSSA, Numeric]:
return cute.math._math_op(arith.divf, fastmath, x, y)
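# divf complements the re-exported unary math ops above: it lowers to MLIR
# arith.divf via cute.math._math_op, with the fastmath flag forwarded so the
# caller can match the kernel's fast-math configuration.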
"""
Thin wrappers around barrier and fence primitives that delegate to the existing
cutlass.cute.arch implementations rather than reimplementing them.
"""
from cutlass.cute.typing import Pointer, Int, Int32, Boolean # noqa: F401
from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op # noqa: F401
from cutlass._mlir.dialects import nvvm
from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive # noqa: F401
from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx # noqa: F401
from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc # noqa: F401
import cutlass.cute.arch as arch
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000, *, loc=None, ip=None) -> None:
"""Waits on a mbarrier with a specified phase."""
nvvm.mbarrier_try_wait_parity_shared(
mbar_ptr.llvm_ptr,
Int32(phase).ir_value(loc=loc, ip=ip),
Int32(timeout_ns).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
@dsl_user_op
def mbarrier_cp_async_arrive(mbar_ptr: Pointer, *, loc=None, ip=None) -> None:
mbar_llvm_ptr = mbar_ptr.llvm_ptr
nvvm.cp_async_mbarrier_arrive_shared(
mbar_llvm_ptr,
noinc=False,
loc=loc,
ip=ip,
)
def fence_proxy_async():
arch.fence_proxy(arch.ProxyKind.async_shared, space=arch.SharedSpace.shared_cta)
def fence_barrier_init():
arch.mbarrier_init_fence()
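# Typical arrival/wait flow, as an illustrative sketch built from the wrappers
# above (the single-producer setup and byte count are assumptions):
#
#   mbarrier_init(mbar_ptr, 1)               # expect one arrival per phase
#   mbarrier_expect_tx(mbar_ptr, num_bytes)  # producer announces the TMA bytes
#   ...                                      # async copy is issued elsewhere
#   mbarrier_wait(mbar_ptr, phase)           # consumer spins (with timeout hint)
#   fence_proxy_async()                      # order smem writes vs. the async proxy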