Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
3b52738d
Commit
3b52738d
authored
Jun 27, 2025
by
Yuxi Chi
Committed by
LeiWang1999
Jun 27, 2025
Browse files
[Enhancement] Add tma bulk copy. (#600)
parent
9232e7b8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
0 deletions
+31
-0
src/tl_templates/cuda/copy_sm90.h
src/tl_templates/cuda/copy_sm90.h
+31
-0
No files found.
src/tl_templates/cuda/copy_sm90.h
View file @
3b52738d
...
@@ -8,6 +8,28 @@
...
@@ -8,6 +8,28 @@
namespace
tl
{
namespace
tl
{
// Issues a non-tensor bulk asynchronous copy of `size` bytes from `gmem_ptr`
// to `smem_ptr` using the PTX `cp.async.bulk` instruction (shared::cluster
// destination, global source). Completion is reported to the mbarrier object
// `smem_mbar` through the `mbarrier::complete_tx::bytes` mechanism.
//
// This function only issues the copy; it does not wait for it.
// NOTE(review): the caller presumably has to arm `smem_mbar` with the expected
// byte count and wait on it before reading `smem_ptr` -- confirm at call sites.
// NOTE(review): the PTX ISA places alignment/size constraints on the operands
// of cp.async.bulk -- confirm callers respect them.
//
// smem_ptr  - destination address (converted with smem_ptr_to_uint).
// gmem_ptr  - source address, passed as a 64-bit register operand.
// smem_mbar - mbarrier object that receives the transaction-byte completion.
// size      - number of bytes to copy.
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar,
                        uint32_t size) {
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // The "memory" clobber keeps the compiler from reordering or caching
  // surrounding memory accesses across the copy, matching the descriptor-based
  // TMA helpers in this file.
  asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
               "bytes [%0], [%1], %2, [%3]; \n"
               :
               : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(size),
                 "r"(smem_int_mbar)
               : "memory");
}
// Multicast variant of the bulk asynchronous global -> shared copy. Emits
// `cp.async.bulk ... .multicast::cluster`, passing `mask` as the trailing
// 16-bit operand of the instruction. Completion is reported to `smem_mbar`
// through the `mbarrier::complete_tx::bytes` mechanism.
//
// This function only issues the copy; it does not wait for it.
// NOTE(review): per the PTX ISA the mask selects the destination CTAs in the
// cluster -- confirm the bit layout expected by callers.
//
// smem_ptr  - destination address (converted with smem_ptr_to_uint).
// gmem_ptr  - source address, passed as a 64-bit register operand.
// smem_mbar - mbarrier object that receives the transaction-byte completion.
// size      - number of bytes to copy.
// mask      - 16-bit multicast mask operand.
TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
                                  uint64_t &smem_mbar, uint32_t size,
                                  uint16_t mask) {
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // The "memory" clobber keeps the compiler from reordering or caching
  // surrounding memory accesses across the copy, matching the descriptor-based
  // TMA helpers in this file.
  asm volatile(
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes."
      "multicast::cluster [%0], [%1], %2, [%3], %4; \n"
      :
      : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar),
        "h"(mask)
      : "memory");
}
TL_DEVICE
void
tma_load
(
const
CUtensorMap
&
descriptor
,
uint64_t
&
smem_mbar
,
TL_DEVICE
void
tma_load
(
const
CUtensorMap
&
descriptor
,
uint64_t
&
smem_mbar
,
void
const
*
const
smem_ptr
,
int32_t
const
&
crd0
)
{
void
const
*
const
smem_ptr
,
int32_t
const
&
crd0
)
{
uint64_t
gmem_int_desc
=
reinterpret_cast
<
uint64_t
>
(
&
descriptor
);
uint64_t
gmem_int_desc
=
reinterpret_cast
<
uint64_t
>
(
&
descriptor
);
...
@@ -103,6 +125,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
...
@@ -103,6 +125,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
:
"memory"
);
:
"memory"
);
}
}
// Issues a non-tensor bulk asynchronous copy of `size` bytes from `smem_ptr`
// (shared::cta source) to `dst_gmem_ptr` (global destination) using the PTX
// `cp.async.bulk ... .bulk_group` instruction.
//
// This function only issues the copy; it does not commit or wait for the bulk
// group.
// NOTE(review): the caller presumably has to commit and wait on the bulk group
// before reusing `smem_ptr` or relying on the global data -- confirm at call
// sites.
//
// dst_gmem_ptr - destination address, passed as a 64-bit register operand.
// smem_ptr     - source address (converted with smem_ptr_to_uint).
// size         - number of bytes to copy.
TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // Operand order in the template is [dst], [src]: %1 is the global
  // destination and %0 is the shared source.
  // The "memory" clobber keeps earlier writes to the source buffer from being
  // sunk below the copy, matching the descriptor-based TMA helpers in this
  // file.
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n"
               :
               : "r"(smem_int_ptr), "l"(dst_gmem_ptr), "r"(size)
               : "memory");
}
TL_DEVICE
void
tma_store
(
const
CUtensorMap
&
descriptor
,
TL_DEVICE
void
tma_store
(
const
CUtensorMap
&
descriptor
,
void
const
*
const
smem_ptr
,
int32_t
const
&
crd0
)
{
void
const
*
const
smem_ptr
,
int32_t
const
&
crd0
)
{
uint64_t
gmem_int_desc
=
reinterpret_cast
<
uint64_t
>
(
&
descriptor
);
uint64_t
gmem_int_desc
=
reinterpret_cast
<
uint64_t
>
(
&
descriptor
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment