Unverified Commit c369d690 authored by Lei Wang, committed by GitHub

[Refactor] Refactor CUDA code generation to simplify eviction policy handling (#721)

* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Refactor CUDA code generation to simplify eviction policy handling

- Updated the `VisitExpr_` handling in `codegen_cuda.cc` so that `tma_load`, `tma_load_im2col`, and `tma_store` rely on the default eviction policy when it is `EVICT_NORMAL`, reducing codegen complexity (a sketch of the emitted calls follows the commit message).
- Removed the conditional assembly paths for `EVICT_NORMAL` in `copy_sm90.h`, streamlining the inline-assembly calls for tensor memory operations.

* lint fix
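For readers skimming the diff, the net effect on the generated device code is that the `CacheHintSm90` template argument is now spelled out only for non-default policies. The snippet below is a standalone sketch of that emission rule, not code from this commit; `EmitTmaCallPrefix` and the `EVICT_FIRST` example policy are illustrative names only.

```cpp
// Standalone sketch (illustrative only) of the call-emission rule introduced
// in codegen_cuda.cc: name the CacheHintSm90 template argument only when the
// eviction policy differs from the default EVICT_NORMAL.
#include <iostream>
#include <sstream>
#include <string>

std::string EmitTmaCallPrefix(const std::string &func,               // e.g. "tl::tma_load"
                              const std::string &eviction_policy) {  // e.g. "EVICT_NORMAL"
  std::ostringstream ss;
  if (eviction_policy != "EVICT_NORMAL") {
    ss << func << "<tl::CacheHintSm90::" << eviction_policy << ">(";
  } else {
    ss << func << "(";  // rely on the default template argument in copy_sm90.h
  }
  return ss.str();
}

int main() {
  std::cout << EmitTmaCallPrefix("tl::tma_load", "EVICT_NORMAL") << "...);\n";
  std::cout << EmitTmaCallPrefix("tl::tma_load", "EVICT_FIRST") << "...);\n";
  // Prints:
  //   tl::tma_load(...);
  //   tl::tma_load<tl::CacheHintSm90::EVICT_FIRST>(...);
  return 0;
}
```

Both spellings select the same `tma_load` overload when the policy is `EVICT_NORMAL`, because the device-side functions in `copy_sm90.h` declare `cache_hint = CacheHintSm90::EVICT_NORMAL` as the default template argument.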
codegen_cuda.cc

@@ -1000,7 +1000,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     auto eviction_policy =
         this->eviction_policy_names_
             [op->args[op->args.size() - 1].as<IntImmNode>()->value];
-    ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
+    // Simplify the code by using the default eviction policy
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
+    } else {
+      ss << "tl::tma_load(";
+    }
     auto desc = op->args[0];
     ss << this->PrintExpr(desc) << ", ";
     if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
@@ -1018,17 +1023,25 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     this->stream << ss.str();
   } else if (op->op.same_as(tl::tma_load_im2col())) {
     std::stringstream ss;
-    ss << "tl::tma_load_im2col<tl::CacheHintSm90::"
-       << this->eviction_policy_names_
-              [op->args[op->args.size() - 1].as<IntImmNode>()->value]
-       << ">";
+    auto eviction_policy =
+        this->eviction_policy_names_
+            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_load_im2col<tl::CacheHintSm90::" << eviction_policy << ">";
+    } else {
+      ss << "tl::tma_load_im2col";
+    }
     print_extern_call_stmt(ss.str(), 0, 1);
   } else if (op->op.same_as(tl::tma_store())) {
     std::stringstream ss;
-    ss << "tl::tma_store<tl::CacheHintSm90::"
-       << this->eviction_policy_names_
-              [op->args[op->args.size() - 1].as<IntImmNode>()->value]
-       << ">";
+    auto eviction_policy =
+        this->eviction_policy_names_
+            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_store<tl::CacheHintSm90::" << eviction_policy << ">";
+    } else {
+      ss << "tl::tma_store";
+    }
     print_extern_call_stmt(ss.str(), 0, 1);
   } else if (op->op.same_as(tl::ptx_ldmatirx())) {
     int trans = Downcast<IntImm>(op->args[0])->value;
...
copy_sm90.h

@@ -41,15 +41,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
-                 "complete_tx::bytes"
-                 " [%0], [%1, {%3}], [%2];"
-                 :
-                 : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-                   "r"(crd0)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
                " [%0], [%1, {%3}], [%2], %4;"
@@ -57,7 +48,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "l"(cache_hint)
               : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -67,15 +57,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
-                 "complete_tx::bytes"
-                 " [%0], [%1, {%3, %4}], [%2];"
-                 :
-                 : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-                   "r"(crd0), "r"(crd1)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
                " [%0], [%1, {%3, %4}], [%2], %5;"
@@ -83,7 +64,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "l"(cache_hint)
               : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -93,15 +73,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
-                 "complete_tx::bytes"
-                 " [%0], [%1, {%3, %4, %5}], [%2];"
-                 :
-                 : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-                   "r"(crd0), "r"(crd1), "r"(crd2)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
                " [%0], [%1, {%3, %4, %5}], [%2], %6;"
@@ -109,7 +80,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
               : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
 TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
@@ -119,15 +89,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
-                 "complete_tx::bytes"
-                 " [%0], [%1, {%3, %4, %5, %6}], [%2];"
-                 :
-                 : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-                   "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
                " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
@@ -135,7 +96,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
               : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -146,15 +106,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
-                 "complete_tx::bytes"
-                 " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
-                 :
-                 : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-                   "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
                " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
@@ -163,7 +114,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                 "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4),
                 "l"(cache_hint)
               : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -176,27 +126,14 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile(
-        "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
-        ":complete_tx::bytes"
-        " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
-        :
-        : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-          "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), "h"(offset_w),
-          "h"(offset_h)
-        : "memory");
-  } else {
-    asm volatile(
-        "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
-        ":complete_tx::bytes.L2::cache_hint"
-        " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;"
-        :
-        : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
-          "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), "h"(offset_w),
-          "h"(offset_h), "l"(cache_hint)
-        : "memory");
-  }
+  asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
+               ":complete_tx::bytes.L2::cache_hint"
+               " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;"
+               :
+               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
+                 "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
+                 "h"(offset_w), "h"(offset_h), "l"(cache_hint)
+               : "memory");
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -204,21 +141,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                          void const *const smem_ptr, int32_t const &crd0) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, "
-                 "{%2}], [%1];"
-                 :
-                 : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group "
-               "::cache_hint [%0, {%2}], [%1], %3;"
+               ".L2::cache_hint [%0, {%2}], [%1], %3;"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0),
                 "l"(cache_hint)
              : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -227,21 +155,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                          int32_t const &crd1) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
-                 "{%2, %3}], [%1];"
-                 :
-                 : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group "
-               "::cache_hint [%0, {%2, %3}], [%1], %4;"
+               ".L2::cache_hint [%0, {%2, %3}], [%1], %4;"
              :
              : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                "l"(cache_hint)
              : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -250,22 +169,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                          int32_t const &crd1, int32_t const &crd2) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, "
-                 "{%2, %3, %4}], [%1];"
-                 :
-                 : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
-                   "r"(crd2)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group "
-               "::cache_hint [%0, {%2, %3, %4}], [%1], %5;"
+               ".L2::cache_hint [%0, {%2, %3, %4}], [%1], %5;"
              :
              : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                "r"(crd2), "l"(cache_hint)
              : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -275,22 +184,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                          int32_t const &crd3) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, "
-                 "{%2, %3, %4, %5}], [%1];"
-                 :
-                 : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
-                   "r"(crd2), "r"(crd3)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group "
-               "::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;"
+               ".L2::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;"
              :
              : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                "r"(crd2), "r"(crd3), "l"(cache_hint)
              : "memory");
-  }
 }

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -300,22 +199,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                          int32_t const &crd3, int32_t const &crd4) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
-    asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, "
-                 "{%2, %3, %4, %5, %6}], [%1];"
-                 :
-                 : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
-                   "r"(crd2), "r"(crd3), "r"(crd4)
-                 : "memory");
-  } else {
   asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group "
-               "::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
+               ".L2::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
              :
              : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
              : "memory");
-  }
 }

 TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
...
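On the device side, each trimmed function keeps only the `.L2::cache_hint` form of the `cp.async.bulk.tensor` instruction and always forwards the hint operand; the `tma_store` variants also change the qualifier from `::cache_hint` to `.L2::cache_hint`, matching the load path. Below is a minimal host-compilable sketch of the default-template-argument pattern these functions rely on; it is not code from this commit, and the enum values and `printf` body are placeholders for the real encodings and PTX emission.

```cpp
// Sketch of the pattern used by copy_sm90.h after this commit: a single code
// path where the cache hint is always passed as an operand, and EVICT_NORMAL
// arrives via the default template argument. Enum values are illustrative only.
#include <cstdint>
#include <cstdio>

enum class CacheHintSm90 : uint64_t { EVICT_NORMAL = 0, EVICT_FIRST = 1 };

template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
void tma_load_sketch(int32_t crd0) {
  // The real 1D function issues
  //   cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint
  // and binds "l"(cache_hint) as the %4 operand; printf stands in for it here.
  std::printf("1d copy, crd0=%d, cache hint=%llu\n", crd0,
              static_cast<unsigned long long>(cache_hint));
}

int main() {
  tma_load_sketch(0);                              // default: EVICT_NORMAL
  tma_load_sketch<CacheHintSm90::EVICT_FIRST>(0);  // explicit non-default policy
  return 0;
}
```

This is why the generator can emit plain `tl::tma_load(...)` for the default policy without changing which instruction variant is issued: the hint operand is still present, it simply carries the `EVICT_NORMAL` value.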