Unverified Commit c369d690 authored by Lei Wang, committed by GitHub

[Refactor] Refactor CUDA code generation to simplify eviction policy handling (#721)

* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Refactor CUDA code generation to simplify eviction policy handling

- Updated the `VisitExpr_` methods in `codegen_cuda.cc` to fall back to the default eviction policy for `tma_load`, `tma_load_im2col`, and `tma_store`, reducing complexity in the emitted calls (a minimal sketch of this pattern follows the commit header below).
- Removed the conditional assembly code for `EVICT_NORMAL` in `copy_sm90.h`, streamlining the inline-assembly calls for tensor memory operations.

* lint fix
parent 2bd2d69e
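For context, here is a minimal standalone sketch of the emission pattern introduced in this commit. The helper name `make_tma_call_prefix` and its string parameter are hypothetical; the real logic lives inside `CodeGenTileLangCUDA::VisitExpr_` and reads the policy name from `eviction_policy_names_`.

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical helper mirroring the new codegen behaviour: only spell out the
// CacheHintSm90 template argument when the policy is not the default one.
std::string make_tma_call_prefix(const std::string &eviction_policy) {
  std::ostringstream ss;
  if (eviction_policy != "EVICT_NORMAL") {
    ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
  } else {
    // EVICT_NORMAL is the default template argument on the device helpers,
    // so the plain call form is sufficient.
    ss << "tl::tma_load(";
  }
  return ss.str();
}

int main() {
  std::cout << make_tma_call_prefix("EVICT_NORMAL") << "\n"; // tl::tma_load(
  std::cout << make_tma_call_prefix("EVICT_FIRST") << "\n";  // tl::tma_load<tl::CacheHintSm90::EVICT_FIRST>(
}
```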
@@ -1000,7 +1000,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
+ // Simplify the code by using the default eviction policy
+ if (eviction_policy != "EVICT_NORMAL") {
+ ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
+ } else {
+ ss << "tl::tma_load(";
+ }
auto desc = op->args[0];
ss << this->PrintExpr(desc) << ", ";
if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
@@ -1018,17 +1023,25 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->stream << ss.str();
} else if (op->op.same_as(tl::tma_load_im2col())) {
std::stringstream ss;
ss << "tl::tma_load_im2col<tl::CacheHintSm90::"
<< this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value]
<< ">";
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
if (eviction_policy != "EVICT_NORMAL") {
ss << "tl::tma_load_im2col<tl::CacheHintSm90::" << eviction_policy << ">";
} else {
ss << "tl::tma_load_im2col";
}
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::tma_store())) {
std::stringstream ss;
ss << "tl::tma_store<tl::CacheHintSm90::"
<< this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value]
<< ">";
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
if (eviction_policy != "EVICT_NORMAL") {
ss << "tl::tma_store<tl::CacheHintSm90::" << eviction_policy << ">";
} else {
ss << "tl::tma_store";
}
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::ptx_ldmatirx())) {
int trans = Downcast<IntImm>(op->args[0])->value;
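The last call argument of each of these intrinsics is an integer index into `eviction_policy_names_`. A hedged, standalone illustration of that lookup (the table's type, order, and contents below are assumptions, not the actual member definition):

```cpp
#include <array>
#include <iostream>
#include <string>

int main() {
  // Hypothetical policy table; the real eviction_policy_names_ member may differ.
  const std::array<std::string, 3> eviction_policy_names = {
      "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};
  // In the codegen, the index comes from the trailing IntImm argument:
  // op->args[op->args.size() - 1].as<IntImmNode>()->value
  int policy_index = 0;
  std::cout << eviction_policy_names[policy_index] << "\n"; // EVICT_NORMAL
}
```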
@@ -41,15 +41,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
- "complete_tx::bytes"
- " [%0], [%1, {%3}], [%2];"
- :
- : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
- "r"(crd0)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3}], [%2], %4;"
@@ -57,7 +48,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -67,15 +57,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
- "complete_tx::bytes"
- " [%0], [%1, {%3, %4}], [%2];"
- :
- : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
- "r"(crd0), "r"(crd1)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4}], [%2], %5;"
@@ -83,7 +64,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -93,15 +73,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
- "complete_tx::bytes"
- " [%0], [%1, {%3, %4, %5}], [%2];"
- :
- : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
- "r"(crd0), "r"(crd1), "r"(crd2)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5}], [%2], %6;"
@@ -109,7 +80,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
@@ -119,15 +89,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
- "complete_tx::bytes"
- " [%0], [%1, {%3, %4, %5, %6}], [%2];"
- :
- : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
- "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
@@ -135,7 +96,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -146,15 +106,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
- "complete_tx::bytes"
- " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
- :
- : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
- "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
@@ -163,7 +114,6 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4),
"l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -176,27 +126,14 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile(
- "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
- ":complete_tx::bytes"
- " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
- :
- : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
- "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), "h"(offset_w),
- "h"(offset_h)
- : "memory");
- } else {
- asm volatile(
- "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
+ asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
":complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), "h"(offset_w),
"h"(offset_h), "l"(cache_hint)
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
"h"(offset_w), "h"(offset_h), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -204,21 +141,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, "
- "{%2}], [%1];"
- :
- : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group "
"::cache_hint [%0, {%2}], [%1], %3;"
".L2::cache_hint [%0, {%2}], [%1], %3;"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0),
"l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -227,21 +155,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
int32_t const &crd1) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
- "{%2, %3}], [%1];"
- :
- : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group "
"::cache_hint [%0, {%2, %3}], [%1], %4;"
".L2::cache_hint [%0, {%2, %3}], [%1], %4;"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -250,22 +169,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
int32_t const &crd1, int32_t const &crd2) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, "
- "{%2, %3, %4}], [%1];"
- :
- : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
- "r"(crd2)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group "
"::cache_hint [%0, {%2, %3, %4}], [%1], %5;"
".L2::cache_hint [%0, {%2, %3, %4}], [%1], %5;"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"r"(crd2), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -275,22 +184,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
int32_t const &crd3) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, "
- "{%2, %3, %4, %5}], [%1];"
- :
- : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
- "r"(crd2), "r"(crd3)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group "
"::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;"
".L2::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"r"(crd2), "r"(crd3), "l"(cache_hint)
: "memory");
- }
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -300,22 +199,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
int32_t const &crd3, int32_t const &crd4) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
- asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, "
- "{%2, %3, %4, %5, %6}], [%1];"
- :
- : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
- "r"(crd2), "r"(crd3), "r"(crd4)
- : "memory");
- } else {
asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group "
"::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
".L2::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
: "memory");
- }
}
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
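The device-side simplification can be illustrated in plain C++. This is a sketch only: the enum values and the function body below are placeholders rather than the real PTX path in `copy_sm90.h`. With `EVICT_NORMAL` as the default template argument and its value forwarded like any other cache hint, the former `if constexpr` special case is no longer needed.

```cpp
#include <cstdint>
#include <cstdio>

// Placeholder values; the real CacheHintSm90 encodings differ.
enum class CacheHintSm90 : uint64_t { EVICT_NORMAL = 0, EVICT_FIRST = 1, EVICT_LAST = 2 };

// Sketch of the post-refactor shape of the copy helpers: a single code path
// that always forwards the hint, with EVICT_NORMAL as the default.
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
void tma_copy_sketch(int coord) {
  // Stand-in for the single .L2::cache_hint asm call that now serves every
  // policy, including the default one.
  std::printf("copy coord=%d hint=%llu\n", coord,
              static_cast<unsigned long long>(cache_hint));
}

int main() {
  tma_copy_sketch(0);                             // default: EVICT_NORMAL
  tma_copy_sketch<CacheHintSm90::EVICT_FIRST>(1); // explicit hint
}
```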