Commit 3b52738d authored by Yuxi Chi, committed by LeiWang1999
Browse files

[Enhancement] Add tma bulk copy. (#600)

parent 9232e7b8
...@@ -8,6 +8,28 @@ ...@@ -8,6 +8,28 @@
namespace tl { namespace tl {
// Issues a 1D bulk asynchronous copy from global memory to shared memory
// (PTX `cp.async.bulk.shared::cluster.global`), with completion reported to an
// mbarrier via the `mbarrier::complete_tx::bytes` mechanism.
//
//   smem_ptr  - destination buffer in shared memory.
//   gmem_ptr  - source buffer in global memory.
//   smem_mbar - mbarrier object (in shared memory) that receives the
//               transaction-byte completion for this copy.
//   size      - number of bytes to copy.
//
// The copy is asynchronous: this function returns once the instruction has
// been issued, not when the data has arrived. The caller must wait on
// `smem_mbar` before reading the destination.
// NOTE(review): the PTX ISA requires 16-byte aligned addresses and a size that
// is a multiple of 16 for cp.async.bulk, and the instruction needs sm_90+ --
// confirm callers guarantee this.
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar,
                        uint32_t size) {
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // The "memory" clobber keeps the compiler from reordering or caching
  // ordinary memory accesses across the copy, which writes shared memory and
  // updates the mbarrier behind the compiler's back. This matches the other
  // TMA wrappers in this file.
  asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
               "bytes [%0], [%1], %2, [%3]; \n"
               :
               : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(size),
                 "r"(smem_int_mbar)
               : "memory");
}
// Multicast variant of the 1D bulk asynchronous global->shared copy
// (PTX `cp.async.bulk.shared::cluster.global ... .multicast::cluster`).
//
//   smem_ptr  - destination buffer in shared memory.
//   gmem_ptr  - source buffer in global memory.
//   smem_mbar - mbarrier object (in shared memory) that receives the
//               transaction-byte completion for this copy.
//   size      - number of bytes to copy.
//   mask      - 16-bit CTA mask passed as the instruction's multicast operand;
//               per the PTX ISA each set bit selects a destination CTA in the
//               cluster.
//
// The copy is asynchronous: this function only issues the instruction. The
// caller must wait on the mbarrier before reading the destination.
// NOTE(review): the PTX ISA requires 16-byte aligned addresses and a size that
// is a multiple of 16 for cp.async.bulk, and the instruction needs sm_90+ --
// confirm callers guarantee this.
TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
                                  uint64_t &smem_mbar, uint32_t size,
                                  uint16_t mask) {
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // The "memory" clobber keeps the compiler from reordering or caching
  // ordinary memory accesses across the copy, which writes shared memory and
  // updates the mbarrier behind the compiler's back. This matches the other
  // TMA wrappers in this file.
  asm volatile(
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes."
      "multicast::cluster [%0], [%1], %2, [%3], %4; \n"
      :
      : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar),
        "h"(mask)
      : "memory");
}
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar, TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0) { void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor); uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
...@@ -103,6 +125,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor, ...@@ -103,6 +125,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
: "memory"); : "memory");
} }
// Issues a 1D bulk asynchronous copy from shared memory to global memory
// (PTX `cp.async.bulk.global.shared::cta.bulk_group`).
//
//   dst_gmem_ptr - destination buffer in global memory.
//   smem_ptr     - source buffer in shared memory.
//   size         - number of bytes to copy.
//
// Completion is tracked through the bulk async-group mechanism
// (`.bulk_group`), not an mbarrier. This function only issues the copy; it
// does not commit or wait on the group, so the caller must do that before
// reusing the shared-memory source or relying on the global data.
// NOTE(review): the PTX ISA requires 16-byte aligned addresses and a size that
// is a multiple of 16 for cp.async.bulk, and the instruction needs sm_90+ --
// confirm callers guarantee this.
TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // Operand order in the PTX text is [dst global], [src shared], size; hence
  // %1 (global pointer) precedes %0 (shared address). The "memory" clobber
  // stops the compiler from sinking earlier shared-memory stores below the
  // copy that reads them. This matches the other TMA wrappers in this file.
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n"
               :
               : "r"(smem_int_ptr), "l"(dst_gmem_ptr), "r"(size)
               : "memory");
}
TL_DEVICE void tma_store(const CUtensorMap &descriptor, TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0) { void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor); uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment