OpenDAS / tilelang / Commits
"vscode:/vscode.git/clone" did not exist on "2beb20f1dc4dc52de92368eee93fc8b7922ec511"
Commit 549416f7, authored Jan 11, 2025 by LeiWang1999
Merge branch 'main' of https://github.com/microsoft/TileLang into main
Parents: 4d63633a, 7fad4e88
Changes: 90
Showing 20 changed files with 1420 additions and 1100 deletions (+1420, -1100).
src/tl_templates/cuda/copy_sm90.h                +142  -119
src/tl_templates/cuda/gemm_sm70.h                +70   -50
src/tl_templates/cuda/gemm_sm80.h                +137  -108
src/tl_templates/cuda/gemm_sm90.h                +132  -82
src/tl_templates/cuda/ldsm.h                     +64   -43
src/tl_templates/cuda/reduce.h                   +9    -14
src/tl_templates/cuda/threadblock_swizzle.h      +15   -11
src/tl_templates/hip/common.h                    +13   -9
src/tl_templates/hip/copy.h                      +38   -29
src/tl_templates/hip/gemm.h                      +57   -37
src/tl_templates/hip/reduce.h                    +9    -14
src/tl_templates/hip/threadblock_swizzle.h       +15   -11
src/transform/cluster_planning.cc                +30   -23
src/transform/common/loop_fusion_utils.h         +30   -23
src/transform/common/loop_vectorization_utils.h  +173  -120
src/transform/frontend_legalize.cc               +10   -10
src/transform/inject_fence_proxy.cc              +27   -29
src/transform/inject_pipeline.cc                 +351  -288
src/transform/layout_inference.cc                +54   -42
src/transform/legalize_safe_memory_access.cc     +44   -38
src/tl_templates/cuda/copy_sm90.h
@@ -8,222 +8,245 @@
(The hunk re-wraps the TMA and mbarrier helpers, splitting long PTX strings
across adjacent string literals, without changing their behavior. The
reformatted definitions read:)

namespace tl {

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
               "complete_tx::bytes"
               " [%0], [%1, {%3}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0)
               : "memory");
}

// The 2-, 3-, 4- and 5-coordinate tma_load overloads follow the same pattern,
// taking crd1..crd4 and issuing cp.async.bulk.tensor.{2,3,4,5}d with
// coordinate lists {%3, %4} through {%3, %4, %5, %6, %7}.

TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
                               uint64_t &smem_mbar, void const *const smem_ptr,
                               int32_t const &coord_c, int32_t const &coord_w,
                               int32_t const &coord_h, int32_t const &coord_n,
                               uint16_t const &offset_w,
                               uint16_t const &offset_h) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
               ":complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
                 "h"(offset_w), "h"(offset_h)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
      : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
               "{%2, %3}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
               : "memory");
}

// The 3-, 4- and 5-coordinate tma_store overloads extend the coordinate list
// to {%2, %3, %4}, {%2, %3, %4, %5} and {%2, %3, %4, %5, %6}.

TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}

TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.init.shared.b64 [%1], %0;"
               :
               : "r"(arrive_count), "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("{\n"
               ".reg .pred P1;\n"
               "LAB_WAIT:\n"
               "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
               "@!P1 bra.uni LAB_WAIT;\n"
               "}\n" ::"r"(smem_int_ptr),
               "r"(phase_bit));
}

TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
                                  uint32_t transaction_bytes) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
               :
               : "r"(transaction_bytes), "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
                                         uint32_t transaction_bytes) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
               :
               : "r"(transaction_bytes), "r"(smem_int_ptr));
}

TL_DEVICE void mbarrier_cp_async_arrive(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
               :
               : "r"(smem_int_ptr));
}

TL_DEVICE void fence_proxy_async() {
  asm volatile("fence.proxy.async.shared::cta;" : :);
}

TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  uint64_t state;
  asm volatile("{\n"
               ".reg .pred P1;\n"
               "mbarrier.arrive.shared.b64 %1, [%0];\n"
               "LAB_WAIT:\n"
               "mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
               "@!P1 bra.uni LAB_WAIT;\n"
               "}\n"
               :
               : "r"(smem_int_ptr), "l"(state));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

} // namespace tl
\ No newline at end of file
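For orientation, a minimal sketch of how the helpers above are typically
combined for a one-shot TMA copy into shared memory. The kernel name, the
64x64 float tile, and the host-side CUtensorMap (e.g. built with
cuTensorMapEncodeTiled) are illustrative assumptions, not taken from this
commit:

// Sketch only: one thread initializes the barrier and issues the bulk copy;
// every thread then waits on the barrier before reading the tile.
__global__ void copy_tile_kernel(const __grid_constant__ CUtensorMap desc) {
  __shared__ alignas(128) float tile[64 * 64];
  __shared__ uint64_t bar;

  if (threadIdx.x == 0) {
    tl::prefetch_tma_descriptor(desc);
    tl::mbarrier_init(bar, /*arrive_count=*/1);
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    // Announce the expected transaction size, then launch the 2D bulk load.
    tl::mbarrier_arrive_expect_tx(bar, sizeof(tile));
    tl::tma_load(desc, bar, tile, /*crd0=*/0, /*crd1=*/0);
  }
  // All threads spin on parity phase 0 until the TMA transaction completes.
  tl::mbarrier_wait(bar, /*phase_bit=*/0);
  // ... consume tile ...
}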
src/tl_templates/cuda/gemm_sm70.h
@@ -13,78 +13,94 @@ using cutlass::gemm::GemmShape;
(Reflowed by the same formatting pass; the code after the change reads:)

// Add 128 bits padding when the last dim is a multiple of 256 bits
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? M : K;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? K : N;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

// Partial specialization for half_t
template <int M, int K>
struct DispatchSharedMemoryLayoutA<
    half_t, true, M, K, typename std::enable_if<M % 64 == 0>::type> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
  static int constexpr Stride = M;
};

template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = M;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = N;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<
    half_t, false, N, K, typename std::enable_if<N % 64 == 0>::type> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
  static int constexpr Stride = N;
};

template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = A_type_raw;
  using B_type = B_type_raw;
  using C_type = C_type_raw;
  using InstructionShape = GemmShape<16, 16, 4>;
  using SMemLayoutA =
      typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                           Shape::kK>::Layout;
  using SMemLayoutB =
      typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                           Shape::kK>::Layout;
  static constexpr int stride_A =
      DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                  Shape::kK>::Stride;
  static constexpr int stride_B =
      DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                  Shape::kK>::Stride;
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          InstructionShape, 32, A_type,
          typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          B_type,
          typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
      cutlass::MatrixShape<1, 1>>;
  static_assert(Shape::kM % num_warp_m == 0);
  static_assert(Shape::kN % num_warp_n == 0);
  using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
      GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
                InstructionShape::kK>,
      A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
      cutlass::layout::RowMajor, Policy>;
  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
...
@@ -97,13 +113,14 @@ class GemmTensorOp {
  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int constexpr kKgroups = Shape::kK / InstructionShape::kK;

  static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                                  FragmentC &accum, const int warp_idx_m,
                                  const int warp_idx_n, const int lane_id) {
    MmaWarp mma_op;
    FragmentA frag_A;
    FragmentB frag_B;
    const TensorRefA ref_A((A_type *)pA, stride_A);
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorA iter_A(ref_A, lane_id);
    IteratorB iter_B(ref_B, lane_id);
    iter_A.add_tile_offset({warp_idx_m, 0});
...
@@ -118,11 +135,12 @@ class GemmTensorOp {
  static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB,
                                     FragmentC &accum, const int warp_idx_n,
                                     const int lane_id) {
    MmaWarp mma_op;
    FragmentB frag_B;
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorB iter_B(ref_B, lane_id);
    iter_B.add_tile_offset({0, warp_idx_n});
    CUTLASS_PRAGMA_UNROLL
...
@@ -136,27 +154,29 @@ class GemmTensorOp {

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n,
            warp_id % num_warp_n, lane_id);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentA = typename MMA::FragmentA;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum),
               warp_id % num_warp_n, lane_id);
}

} // namespace tl
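A minimal call-site sketch for the tl::gemm_ss wrapper above. The 64x64x32
tile, the 2x2 warp split, the layout flags and the half-precision operand
types are illustrative assumptions; the block is assumed to run
num_warp_m * num_warp_n * 32 threads so the warp/lane indexing above works out:

// Sketch: A and B tiles already staged in shared memory, acc points to this
// thread's accumulator fragment in registers.
__device__ void mma_step(cutlass::half_t *A_smem, cutlass::half_t *B_smem,
                         float *acc) {
  tl::gemm_ss<64, 64, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
              /*trans_A=*/false, /*trans_B=*/false,
              cutlass::half_t, cutlass::half_t, float>(A_smem, B_smem, acc);
}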
src/tl_templates/cuda/gemm_sm80.h
@@ -12,39 +12,32 @@ template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <> struct DispatchInstruction<half_t, half_t, half_t> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<half_t, half_t, float> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<bfloat16_t, bfloat16_t, float> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<tfloat32_t, tfloat32_t, float> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<int8_t, int8_t, int> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <> struct DispatchInstruction<double, double, double> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Layout<Shape<_2, _2, _1>>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <> struct DispatchInstruction<half_t, half_t, float> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Layout<Shape<_1, _2, _2>>;
};
...
@@ -54,149 +47,175 @@ template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

// Analogous specializations follow for 32-bit operands (K % 32 == 0 / 16 with
// Swizzle<3, 2, 3> / Swizzle<2, 2, 3> and SM75_U32x4_LDSM_N; N % 32 == 0 / 16
// with the transposed atoms and UniversalCopy<tfloat32_t>), for 8-bit operands
// (K % 128 == 64 / 0 with Swizzle<2, 4, 3> / Swizzle<3, 4, 3> and
// SM75_U32x4_LDSM_N), and for 64-bit operands (K % 16 == 0 with
// Swizzle<2, 0, 4>, N % 16 == 0 with Swizzle<2, 2, 2>, both DefaultCopy).

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction = DispatchInstruction<A_type, B_type, C_type>;
  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;
  using TileMma =
      TiledMMA<typename Instruction::MMA,
               Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
               typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
...
(The remaining hunks at old lines 212, 226, 239, 255, 268 and 285 rewrap the
acc / tCrA / tCrB make_tensor(..., partition_shape_*) statements, the
"KxN with n_warp == 1" workaround comment and the remove_swizzle views in
body(), body_rs() and body_sr() without changing them, and reformat the
tl-namespace wrappers:)

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl
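The padding rule in the primary OperandTraits template above can be checked in
isolation. A small compile-time sketch follows; the helper name and the
concrete values are illustrative, not part of the header:

// Sketch: mirrors `padded` from the primary OperandTraits template. A stride
// that is a multiple of 256 bits gets 128 bits of padding appended; anything
// else is left unchanged.
constexpr int padded_stride(int bits, int stride) {
  return stride % (256 / bits) == 0 ? stride + 128 / bits : stride;
}
static_assert(padded_stride(16, 64) == 72, "64 halves -> pad by 8 halves");
static_assert(padded_stride(16, 72) == 72, "72 is not a multiple of 256 bits");
static_assert(padded_stride(32, 32) == 36, "32 tf32 words -> pad by 4 words");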
src/tl_templates/cuda/gemm_sm90.h
@@ -2,9 +2,9 @@
// Licensed under the MIT License.

#pragma once

(The cutlass/cute includes are re-sorted:)
#include <cute/algorithm/copy.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>

#include "common.h"
...
@@ -19,78 +19,112 @@ CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
  static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");

  if constexpr (major == GMMA::Major::MN) {
    if constexpr (BLK_MN0 %
                      size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) ==
                  0) {
      return GMMA::Layout_MN_SW128_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 % size<0>(
                             GMMA::Layout_MN_SW64_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_SW64_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 % size<0>(
                             GMMA::Layout_MN_SW32_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_SW32_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 % size<0>(
                             GMMA::Layout_MN_INTER_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_INTER_Atom<ElementType>{};
    } else {
      static_assert(
          BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
          "BLK_MN0 must be a multiple of "
          "size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
    }
  } else if constexpr (major == GMMA::Major::K) {
    // The same cascade, keyed on BLK_K0 and size<1>(...), runs over
    // GMMA::Layout_K_SW128_Atom / _SW64_ / _SW32_ / _INTER_Atom and ends in
    // the matching static_assert on Layout_K_INTER_Atom.
  }
}
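ss_smem_selector thus returns the widest swizzled GMMA layout atom whose
extent divides the block tile, falling back to the interleaved atom. A usage
sketch mirroring how GemmTensorOp consumes it below; the K-major 128x64 half_t
extents are illustrative only:

// Sketch: pick the layout atom for a K-major 128x64 half_t operand, then tile
// it to the full extent, as SmemLayoutA does in GemmTensorOp below.
using AtomA = decltype(ss_smem_selector<GMMA::Major::K, half_t,
                                        Int<128>, Int<64>>());
using SmemA = decltype(tile_to_shape(AtomA{}, Shape<Int<128>, Int<64>>{},
                                     Step<_1, _2>{}));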
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  // static_assert(num_warp_n == 1);
  static_assert(num_warp_m % 4 == 0);

  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<A_type, B_type, C_type,
                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
                             GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    // Allocate registers for pipelining
    Tensor tCsA = thr_mma.partition_A(sA);       // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE)
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
...
@@ -103,7 +137,9 @@ class GemmTensorOp {
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();
...
@@ -115,25 +151,31 @@ class GemmTensorOp {
    // warpgroup_fence_operand(acc);
  }

  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<A_type, B_type, C_type,
                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
                             GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    // Allocate registers for pipelining
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
...
@@ -146,7 +188,9 @@ class GemmTensorOp {
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);
...
@@ -156,57 +200,63 @@ class GemmTensorOp {
    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
    // warpgroup_commit_batch();
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
};

} // namespace cute

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body<wg_wait>(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_rs<wg_wait>(pA, pB, accum);
}

template <int num_mma> TL_DEVICE void wait_wgmma() {
  warpgroup_wait<num_mma>();
}

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

template <int NumMmaThreads> TL_DEVICE void
(the diff is truncated here)
warp_scheduler_barrier_sync
()
{
cutlass
::
arch
::
NamedBarrier
::
sync
(
NumMmaThreads
,
cutlass
::
canonical_warp_group_idx
()
/*id*/
);
}
template
<
int
NumMmaThreads
>
TL_DEVICE
void
warp_scheduler_barrier_arrive
()
{
template
<
int
NumMmaThreads
>
TL_DEVICE
void
warp_scheduler_barrier_arrive
()
{
static_assert
(
NumMmaThreads
==
256
||
NumMmaThreads
==
384
);
if
constexpr
(
NumMmaThreads
==
256
)
{
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
(
1
-
cutlass
::
canonical_warp_group_idx
())
/*id*/
);
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
(
1
-
cutlass
::
canonical_warp_group_idx
())
/*id*/
);
}
else
{
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
(
cutlass
::
canonical_warp_group_idx
()
<=
1
?
cutlass
::
canonical_warp_group_idx
()
+
1
:
cutlass
::
canonical_warp_group_idx
()
+
1
-
3
)
/*id*/
);
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
(
cutlass
::
canonical_warp_group_idx
()
<=
0
?
cutlass
::
canonical_warp_group_idx
()
+
2
:
cutlass
::
canonical_warp_group_idx
()
+
2
-
3
)
/*id*/
);
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
(
cutlass
::
canonical_warp_group_idx
()
<=
1
?
cutlass
::
canonical_warp_group_idx
()
+
1
:
cutlass
::
canonical_warp_group_idx
()
+
1
-
3
)
/*id*/
);
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
(
cutlass
::
canonical_warp_group_idx
()
<=
0
?
cutlass
::
canonical_warp_group_idx
()
+
2
:
cutlass
::
canonical_warp_group_idx
()
+
2
-
3
)
/*id*/
);
}
}
template
<
int
NumMmaThreads
>
TL_DEVICE
void
mma_init
()
{
template
<
int
NumMmaThreads
>
TL_DEVICE
void
mma_init
()
{
static_assert
(
NumMmaThreads
==
256
||
NumMmaThreads
==
384
);
if
(
cutlass
::
canonical_warp_group_idx
()
>
0
)
{
cutlass
::
arch
::
NamedBarrier
::
arrive
(
NumMmaThreads
,
0
);
...
...
@@ -217,4 +267,4 @@ TL_DEVICE void mma_init() {
}
}
}
}
// namespace tl
}
// namespace tl
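Note (not part of the commit): a minimal sketch of how a generated TileLang kernel is expected to call these wrappers. The tile shape, warp counts, element types and function name below are illustrative assumptions, not code from this diff.

// Hypothetical call site: 128x128 output tile, K = 64 step, operands staged
// in shared memory (hence the "ss" variant), fp16 inputs, fp32 accumulator.
__device__ void example_tile_mma(cutlass::half_t *A_smem,
                                 cutlass::half_t *B_smem, float *acc_frag) {
  // num_warp_m = 4 gives one warp group along M; num_warp_n = 2 splits N.
  tl::gemm_ss<128, 128, 64, /*num_warp_m=*/4, /*num_warp_n=*/2,
              /*trans_A=*/false, /*trans_B=*/true, /*wg_wait=*/0,
              cutlass::half_t, cutlass::half_t, float>(A_smem, B_smem,
                                                       acc_frag);
}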
src/tl_templates/cuda/ldsm.h
...
@@ -6,97 +6,118 @@
namespace tl {

TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
                               void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(value[0])
               : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
                               void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(value[0]), "=r"(value[1])
               : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
                               void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(value[0])
               : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
      : "=r"(value[0]), "=r"(value[1])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
                                     void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}

TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr,
                               const int32_t &value0) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
                   smem_int_ptr),
               "r"(value0));
}

TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr,
                               const int32_t &value0, const int32_t &value1) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
          smem_int_ptr),
      "r"(value0), "r"(value1));
}

TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr,
                               const int32_t &value0, const int32_t &value1,
                               const int32_t &value2, const int32_t &value3) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(
          smem_int_ptr),
      "r"(value0), "r"(value1), "r"(value2), "r"(value3));
}

TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr,
                                     const int32_t &value0) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
          smem_int_ptr),
      "r"(value0));
}

TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr,
                                     const int32_t &value0,
                                     const int32_t &value1) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
          smem_int_ptr),
      "r"(value0), "r"(value1));
}

TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr,
                                     const int32_t &value0,
                                     const int32_t &value1,
                                     const int32_t &value2,
                                     const int32_t &value3) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, "
               "%3, %4};\n" ::"r"(smem_int_ptr),
               "r"(value0), "r"(value1), "r"(value2), "r"(value3));
}

} // namespace tl
\ No newline at end of file
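Note (not part of the commit): ldmatrix is a warp-collective instruction, so every lane of the warp passes a shared-memory address and receives its own destination registers. The sketch below only illustrates the call shape; the function name and the way the per-lane address is computed are assumptions.

// Illustrative only: each lane supplies the shared-memory address of the
// 8x8 b16 sub-tile row assigned to it; the x4 form fills four b32 registers.
__device__ void load_tile_fragment(const void *smem_addr_for_this_lane) {
  uint32_t frag[4]; // per-lane destination registers
  tl::ptx_ldmatrix_x4(smem_addr_for_this_lane, frag);
  // frag[0..3] now hold packed fp16 data in the register layout that
  // mma.sync expects for its A/B operands.
}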
src/tl_templates/cuda/reduce.h
...
@@ -7,34 +7,29 @@
namespace tl {

struct SumOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x + y;
  }
};

struct MaxOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return cutlass::fast_max(x, y);
  }
};

struct MinOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return cutlass::fast_min(x, y);
  }
};

template <class Reducer, int threads, int scale> struct AllReduce {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                threads == 128 or threads == 64 or threads == 32 or
                threads == 16 or threads == 8 or threads == 4 or threads == 2);
  static_assert(threads % scale == 0);
  template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
    constexpr int offset = threads / 2;
    if constexpr (offset >= 32) {
      __syncthreads();
...
@@ -54,4 +49,4 @@ struct AllReduce {
  }
};

} // namespace tl
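Note (not part of the commit): a usage sketch of the reducer, with assumed block size and buffer sizing. The kernel name and the choice of one scratch float per thread are illustrative, not taken from this diff.

// Block-wide sum over 256 threads; red_buf is shared-memory scratch used
// once the reduction crosses warp boundaries.
__global__ void block_sum_example(const float *in, float *out) {
  __shared__ float red_buf[256]; // assumed sizing: one slot per thread
  float v = in[blockIdx.x * blockDim.x + threadIdx.x];
  float total = tl::AllReduce<tl::SumOp, 256, 1>::run(v, red_buf);
  if (threadIdx.x == 0)
    out[blockIdx.x] = total; // every thread holds the reduced value
}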
src/tl_templates/cuda/threadblock_swizzle.h
...
@@ -6,8 +6,7 @@
namespace tl {

template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.x;
...
@@ -15,15 +14,17 @@ TL_DEVICE dim3 rasterization2DRow() {
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
  const unsigned int stride =
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.x;
  const unsigned int col_idx = (panel_idx & 1)
                                   ? gridDim.x - 1 - panel_offset / stride
                                   : panel_offset / stride;
  const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.y;
...
@@ -31,11 +32,14 @@ TL_DEVICE dim3 rasterization2DColumn() {
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
  const unsigned int stride =
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.y;
  const unsigned int row_idx = (panel_idx & 1)
                                   ? gridDim.y - 1 - panel_offset / stride
                                   : panel_offset / stride;
  const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

} // namespace tl
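Note (not part of the commit): kernels are expected to remap their logical block coordinates at entry so that consecutive hardware blocks walk the output in panels, which improves L2 reuse. The panel width, kernel name and interpretation of x/y below are illustrative assumptions.

template <int kPanelWidth = 8>
__global__ void swizzled_kernel_example(/* ... tile arguments ... */) {
  // Replace raw (blockIdx.x, blockIdx.y) with the swizzled coordinates;
  // the returned dim3 packs (column, row, batch) per the code above.
  const dim3 bid = tl::rasterization2DColumn<kPanelWidth>();
  const unsigned n_block = bid.x, m_block = bid.y, batch = bid.z;
  // ... index the output tile with (m_block, n_block, batch) ...
}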
src/tl_templates/hip/common.h
...
@@ -2,10 +2,10 @@
// Licensed under the MIT License.
#pragma once

#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <rocwmma/rocwmma.hpp>
#include <ck_tile/core.hpp>

using ck_tile::half_t;
...
@@ -36,12 +36,16 @@ using ck_tile::half_t;
using float16_t = _Float16;
using float16x2 =
    __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4 =
    __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 =
    __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 =
    __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;

using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
...
@@ -49,7 +53,7 @@ using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;

// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
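Note (not part of the commit): a tiny illustration of the packing helper; the result layout follows directly from the shift above, with y in the high half-word and x in the low half-word.

// Illustrative wrapper only; the function name is a placeholder.
__device__ unsigned pack_pair_example(half_t a, half_t b) {
  // bits [31:16] <- b, bits [15:0] <- a
  return __pack_half2(a, b);
}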
src/tl_templates/hip/copy.h
...
@@ -16,12 +16,13 @@ using index_t = u32;
using ck_tile::int32x4_t;

struct __attribute__((packed)) buffer_resource {
  const void *ptr;
  uint32_t range;
  uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
                                                   uint32_t size = 0xffffffff) {
  buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
  int32x4_t r = __builtin_bit_cast(int32x4_t, res);
  r.x = __builtin_amdgcn_readfirstlane(r.x);
...
@@ -56,48 +57,56 @@ __device__ void async_gld_sld_fence(index_t cnt) {
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }

template <int N = 0> TL_DEVICE void cp_async_wait() {
  async_gld_fence(N);
  // or
  // async_gld_sld_fence(N);
}

template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
                                              index_t voffset) {
  auto const lds_ptr_sgpr =
      __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
  asm volatile("s_mov_b32 m0, %0; \n\t"
               "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
               "v"(voffset), "s"(rsrc)
               : "memory");
}

template <int N>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
  } else if constexpr (N == 4) {
    async_buffer_load_dword_v(
        lds_base_ptr,
        make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
        threadIdx.x * N /*assume 4 bytes*/);
  }
}

template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
                                       void *global_base_ptr, bool cond) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr =
        cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr =
        cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
  } else {
    if (cond) {
      async_buffer_load_dword_v(
          lds_base_ptr,
          make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
          threadIdx.x * N /*assume 4 bytes*/);
    } else {
      *(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0);
    }
  }
}

} // namespace tl
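Note (not part of the commit): a sketch of how a staging loop is expected to use these helpers. The tile size, indexing and kernel name are assumptions; for N == 16 the copy is a plain vector load/store, so only the fence and barrier ordering matter here.

__global__ void stage_tile_example(const float *gptr) {
  __shared__ float tile[2048]; // assumed LDS staging buffer
  // 16 bytes (4 floats) per thread per call; offsets are illustrative.
  tl::cp_async_gs<16>(&tile[threadIdx.x * 4], (void *)&gptr[threadIdx.x * 4]);
  tl::cp_async_wait<0>(); // drain outstanding global loads
  __syncthreads();        // make the staged tile visible to all waves
  // ... consume `tile` ...
}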
src/tl_templates/hip/gemm.h
...
@@ -6,12 +6,12 @@
namespace tl {

// ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
          bool TransposeB, int kPack, typename A_type, typename B_type,
          typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
  static constexpr int micro_size_x = 16;
  static constexpr int micro_size_y = 16;
  static constexpr int micro_size_k = 16;
...
@@ -28,7 +28,8 @@ class GemmTensorOp {
  static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
  static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);

  // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
  // part.
  static constexpr bool kPadA = true;
  static constexpr bool kPadB = true;
  static constexpr bool kPadC = true;
...
@@ -37,12 +38,16 @@ class GemmTensorOp {
  static constexpr int warp_size = 64;

  TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
                                                    int local_id) {
    return std::make_pair(thread_id % 16,
                          (thread_id / 16) * (4 * kPack) + local_id);
  }

  TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
                                                               int local_id) {
    return std::make_pair((thread_id / 16) * (4 * kPack) + local_id,
                          thread_id % 16);
  }

  /*
...
@@ -62,7 +67,8 @@ class GemmTensorOp {
    const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
    const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
    const int maxPhase =
        std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
    const int phase = (row / perPhase) % maxPhase;
    const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
...
@@ -73,16 +79,19 @@ class GemmTensorOp {
  }

  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static constexpr auto make_layout_padded(const int row,
                                                     const int col) {
    return std::make_pair(row, col);
  }

  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                      const int col) {
    constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8);
    if (continuous % (vector_size * 4) == 0) {
      auto [n_row, n_col] =
          make_mfma_swizzle_layout<continuous, element_size>(row, col);
      return n_row * continuous + n_col;
    } else {
      auto [n_row, n_col] = make_layout_padded(row, col);
...
@@ -93,7 +102,8 @@ class GemmTensorOp {
    }
  }

  static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
                             C_type *C_local) {
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
...
@@ -122,7 +132,8 @@ class GemmTensorOp {
      for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
        auto [row, col] = reverse_index_map(lane_id, local_id);
        A_local[i * kPack * local_size_a + local_id] =
            A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(l + row,
                                                                     r + col)];
      }
    }
...
@@ -133,7 +144,8 @@ class GemmTensorOp {
      for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
        auto [row, col] = reverse_index_map(lane_id, local_id);
        B_local[j * kPack * local_size_b + local_id] =
            B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row,
                                                                     r + col)];
      }
    }
...
@@ -141,17 +153,19 @@ class GemmTensorOp {
    for (int kp = 0; kp < kPack; kp++) {
      for (int i = 0; i < warp_rows; ++i) {
        for (int j = 0; j < warp_cols; ++j) {
          *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
              __builtin_amdgcn_mfma_f32_16x16x16f16(
                  *(((float16x4 *)B_local) + j * kPack + kp),
                  *(((float16x4 *)A_local) + i * kPack + kp),
                  *(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
        }
      }
    }
  }

  static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
                                C_type *C_local) {
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
...
@@ -179,7 +193,8 @@ class GemmTensorOp {
      for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
        auto [row, col] = reverse_index_map(lane_id, local_id);
        B_local[j * local_size_b + local_id] =
            B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row,
                                                                     r + col)];
      }
    }
...
@@ -187,9 +202,12 @@ class GemmTensorOp {
    for (int kp = 0; kp < kPack; kp++) {
      for (int i = 0; i < warp_rows; ++i) {
        for (int j = 0; j < warp_cols; ++j) {
          *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
              __builtin_amdgcn_mfma_f32_16x16x16f16(
                  *(((float16x4 *)B_local) + j * kPack + kp),
                  *(((float16x4 *)A_local) + ki * warp_rows * kPack +
                    i * kPack + kp),
                  *(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
        }
      }
    }
...
@@ -197,24 +215,26 @@ class GemmTensorOp {
  }
};

} // namespace tl

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int kPack, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                               trans_B, kPack, A_type, B_type, C_type>;
  Compute::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int kPack, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                               trans_B, kPack, A_type, B_type, C_type>;
  Compute::body_rs(pA, pB, accum);
}

} // namespace tl
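Note (not part of the commit): on the HIP path each (warp_row, warp_col, kPack) step issues one 16x16x16 MFMA that accumulates in place. A usage sketch with assumed tile shape, warp grid and types is shown below; only tl::gemm_ss itself comes from this file.

// Assumed 128x128x32 tile, 2x2 warps, kPack = 1, fp16 in / fp32 accumulate.
__device__ void example_tile_mma_hip(half_t *A_lds, half_t *B_lds,
                                     float *C_frag) {
  tl::gemm_ss<128, 128, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
              /*trans_A=*/false, /*trans_B=*/false, /*kPack=*/1,
              half_t, half_t, float>(A_lds, B_lds, C_frag);
}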
src/tl_templates/hip/reduce.h
...
@@ -7,35 +7,30 @@
namespace tl {

struct SumOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x + y;
  }
};

struct MaxOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return ck_tile::max(x, y);
  }
};

struct MinOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return ck_tile::min(x, y);
  }
};

template <class Reducer, int threads, int scale> struct AllReduce {
  static_assert(threads == 1024 || threads == 512 || threads == 256 ||
                threads == 128 || threads == 64 || threads == 32 ||
                threads == 16 || threads == 8 || threads == 4 || threads == 2);
  static_assert(threads % scale == 0);
  template <typename T> static __device__ T run(T x, T *red_buf = nullptr) {
    constexpr int offset = threads / 2;
    constexpr int warpSize = 64;
...
@@ -55,4 +50,4 @@ struct AllReduce {
  }
};

} // namespace tl
src/tl_templates/hip/threadblock_swizzle.h
...
@@ -6,8 +6,7 @@
namespace tl {

template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
...
@@ -16,15 +15,17 @@ TL_DEVICE dim3 rasterization2DRow() {
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = ceil_div(grid_size, panel_size);
  const unsigned int stride =
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.x;
  const unsigned int col_idx = (panel_idx & 1)
                                   ? gridDim.x - 1 - panel_offset / stride
                                   : panel_offset / stride;
  const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
...
@@ -33,11 +34,14 @@ TL_DEVICE dim3 rasterization2DColumn() {
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = ceil_div(grid_size, panel_size);
  const unsigned int stride =
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.y;
  const unsigned int row_idx = (panel_idx & 1)
                                   ? gridDim.y - 1 - panel_offset / stride
                                   : panel_offset / stride;
  const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

} // namespace tl
src/transform/cluster_planning.cc
...
@@ -31,15 +31,17 @@ namespace tvm {
namespace tir {

class ClusterPlanner {
public:
  static PrimFunc Substitute(PrimFunc &f) {
    // Step 1: Collect the read region of the function
    Map<Var, Buffer> buffer_data_to_buffer_;
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                /*name_hint=*/"", /*body*/ f->body);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];

    BlockIdxVisitor blockIdx_visitor;
...
@@ -47,20 +49,22 @@ class ClusterPlanner {
    auto dom_map = blockIdx_visitor.dom_map_;

    // Step 2: Collect mem reuse count for clustering on each dimension.
    std::unordered_map<const IterVarNode *, size_t> mem_reuse_count;
    for (auto iv : dom_map)
      mem_reuse_count[iv] = 0;
    for (const auto &buffer_region : reads) {
      PrimExpr size = buffer_region->buffer->dtype.bits();
      RegionVisitor visitor;
      for (const auto &range : buffer_region->region) {
        size = size * range->extent;
        visitor(range->min);
      }
      size = arith::Analyzer().Simplify(size);
      if (auto imm = size.as<IntImmNode>()) {
        for (auto iv : dom_map) {
          if (visitor.seen_.count(iv->var.get()) == 0)
            mem_reuse_count[iv] += imm->value;
        }
      }
    }
...
@@ -70,7 +74,8 @@ class ClusterPlanner {
    String cluster_tag;
    for (auto iv : dom_map) {
      if (auto extent = iv->dom->extent.as<IntImmNode>()) {
        if (extent->value % cluster_size_ == 0 &&
            mem_reuse_count[iv] > mem_reuse_max) {
          cluster_tag = iv->thread_tag;
          mem_reuse_max = mem_reuse_count[iv];
        }
...
@@ -78,27 +83,28 @@ class ClusterPlanner {
    }
    if (mem_reuse_max > 0) {
      cluster_tag =
          "clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx"));
      return WithAttr(f, cluster_tag, Integer(cluster_size_));
    } else {
      return f;
    }
  }

private:
  ClusterPlanner() = default;

  class RegionVisitor : public ExprVisitor {
  public:
    RegionVisitor(){};
    void VisitExpr_(const VarNode *var) { seen_.insert(var); }
    std::unordered_set<const VarNode *> seen_;
  };

  class BlockIdxVisitor : public StmtVisitor {
  public:
    BlockIdxVisitor(){};
    void VisitStmt_(const AttrStmtNode *attr) final {
      if (attr->attr_key == attr::thread_extent) {
        IterVar iv = Downcast<IterVar>(attr->node);
        String tag = iv->thread_tag;
...
@@ -108,7 +114,7 @@ class ClusterPlanner {
      StmtVisitor::VisitStmt_(attr);
    }
    /*! \brief The map from vars to blockIdx extents. */
    std::unordered_set<const IterVarNode *> dom_map_;
  };

  /*! \brief Currently set the possible cluster size as 2 */
...
@@ -126,8 +132,9 @@ tvm::transform::Pass ClusterPlanning() {
  return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}

TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning")
    .set_body_typed(ClusterPlanning);

} // namespace transform
} // namespace tir
} // namespace tvm
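Note (not part of the commit): the registered pass can be applied to a lowered module like any other PrimFunc pass. A sketch, assuming an existing tvm::IRModule `mod`:

// C++ sketch; `mod` is an assumed IRModule containing TileLang PrimFuncs.
tvm::transform::Pass cluster_planning = tvm::tir::transform::ClusterPlanning();
tvm::IRModule planned = cluster_planning(mod);
// PrimFuncs whose best-reuse blockIdx dimension is divisible by the cluster
// size now carry a "clusterIdx.*" attribute with the planned cluster size.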
src/transform/common/loop_fusion_utils.h
...
@@ -32,10 +32,10 @@
#include <queue>

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"

namespace tvm {
namespace tl {
...
@@ -44,15 +44,15 @@ using namespace tir;
using arith::IRMutatorWithAnalyzer;

class FragmentAccessDetector : public StmtExprVisitor {
public:
  FragmentAccessDetector() = default;

  void Collect(Stmt stmt) { VisitStmt(stmt); }

  bool HasFragmentAccess() { return has_fragment_access_; }

private:
  void VisitExpr_(const BufferLoadNode *op) final {
    // Check if the buffer is in global scope
    if (IsFragmentBuffer(op->buffer)) {
      has_fragment_access_ = true;
...
@@ -60,7 +60,7 @@ class FragmentAccessDetector : public StmtExprVisitor {
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    // Check if the buffer is in global scope
    if (IsFragmentBuffer(op->buffer)) {
      has_fragment_access_ = true;
...
@@ -69,8 +69,9 @@ class FragmentAccessDetector : public StmtExprVisitor {
  }

  // Helper function to determine if a buffer is local.fragment
  bool IsFragmentBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes.
    String scope = buffer.scope();
    return scope == "local.fragment";
  }
...
@@ -87,23 +88,25 @@ class FragmentAccessDetector : public StmtExprVisitor {
 * Once fused, a single loop variable will replace the chain, and the
 * original loop variables will be derived by division and modulo operations.
 *
 * This can be helpful for inferring layout for the fragment in a subsequent
 * pass.
 */
class ParallelLoopFuser : public IRMutatorWithAnalyzer {
public:
  static Stmt Fuse(Stmt stmt) {
    arith::Analyzer analyzer;
    ParallelLoopFuser substituter(&analyzer);
    return substituter.VisitStmt(stmt);
  }

private:
  ParallelLoopFuser(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer){};

  Stmt VisitStmt_(const ForNode *op) final {
    // Gather consecutive parallel loops
    std::vector<const ForNode *> loop_chain;
    const ForNode *current = op;

    // check if has fragment access
    FragmentAccessDetector detector;
    detector.Collect(op->body);
...
@@ -113,11 +116,13 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
    }
    while (true) {
      if (current->kind != ForKind::kParallel)
        break;
      if (!is_zero(current->min))
        break;
      loop_chain.push_back(current);
      const ForNode *inner_for = current->body.as<ForNode>();
      if (!inner_for) {
        break;
      }
...
@@ -147,7 +152,7 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
    Var fused_var(fused_name, DataType::Int(32));

    // The body of the last loop in the chain:
    const ForNode *innermost_loop = loop_chain.back();
    Stmt body = innermost_loop->body;

    // We need to substitute all loop variables in the chain.
...
@@ -175,7 +180,8 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
      extents.push_back(l->extent);
    }
    std::vector<PrimExpr> strides(loop_chain.size(),
                                  make_const(DataType::Int(32), 1));
    for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * extents[i + 1];
    }
...
@@ -189,8 +195,9 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
    Map<Var, PrimExpr> var_map;
    for (size_t i = 0; i < loop_chain.size(); i++) {
      const ForNode *loop = loop_chain[i];
      var_map.Set(loop->loop_var,
                  analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
    }

    // Perform the substitution
...
@@ -203,5 +210,5 @@ class ParallelLoopFuser : public IRMutatorWithAnalyzer {
  }
};

} // namespace tl
} // namespace tvm
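Note (not part of the commit): a small arithmetic sketch of the recovery the fuser builds. With assumed extents 4 and 8 the suffix-product strides are {8, 1}, and each original loop variable is reconstructed from the fused variable by divide and modulo.

// Plain C++ analogue of the fused index expressions for a 2-loop chain.
void fused_iteration_example() {
  const int extent_i = 4, extent_j = 8;
  for (int fused = 0; fused < extent_i * extent_j; ++fused) {
    int i = fused / extent_j; // stride of i is extent_j == 8
    int j = fused % extent_j; // stride of j is 1
    // (i, j) enumerates exactly the same pairs as the original nest.
    (void)i;
    (void)j;
  }
}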
src/transform/common/loop_vectorization_utils.h
...
@@ -32,10 +32,10 @@
#include <queue>

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"

namespace tvm {
namespace tl {
...
@@ -46,7 +46,8 @@ using namespace tir;
// Use the same code as tir.transform.vectorize_loop
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
  if (is_scalable) {
    return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
               lanes_or_vscale_factor);
  } else {
    return lanes_or_vscale_factor;
  }
...
@@ -58,7 +59,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
      e.dtype().is_scalable_vector() == is_scalable)
    return e;
  if (const BroadcastNode *op = e.as<BroadcastNode>()) {
    ICHECK(op->dtype.is_scalable_vector() == is_scalable)
        << "Can't broadcast between scalable and fixed length vectors.";
    int e_lanes = op->dtype.get_lanes_or_vscale_factor();
...
@@ -68,40 +69,39 @@
  }
  }
  ICHECK(e.dtype().is_scalar())
      << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor()
      << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes;
  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}

// Rewrite vectorized allocation access
// This is necessary for making each vector component contain its own
// workspace. Originates from Halide's loop vectorizer
//
// s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple
// context.
//
class VecAllocAccess : public StmtExprMutator {
public:
  VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
      : buf_(buf), var_(var), var_lanes_(var_lanes) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return UpdateBufferAccess(load);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return UpdateBufferAccess(store);
  }

private:
  template <typename Node> Node UpdateBufferAccess(Node node) {
    // Only update the buffer that's being replaced.
    if (node->buffer->data.get() != buf_) {
      return node;
...
@@ -117,7 +117,8 @@ class VecAllocAccess : public StmtExprMutator {
      // var_lanes_. Typically, this will be a 1-d index into a flat
      // memory space.
      Array<PrimExpr> shape = node->buffer->shape;
      shape.Set(shape.size() - 1,
                analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));

      // TODO(Lunderberg): Move this pass to be prior to
      // StorageFlatten/FlattenBuffer, implement by appending a
...
@@ -146,8 +147,9 @@ class VecAllocAccess : public StmtExprMutator {
    // Extend the last index by the number of lanes in the vectorized
    // variable.
    Array<PrimExpr> indices = node->indices;
    indices.Set(indices.size() - 1,
                analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ +
                                   var_));

    auto writer = node.CopyOnWrite();
    writer->buffer = buf;
...
@@ -156,9 +158,9 @@ class VecAllocAccess : public StmtExprMutator {
  }

  // buffer var
  const VarNode *buf_;
  // Updated buffer objects.
  std::unordered_map<const BufferNode *, Buffer> buffer_map_;
  // variable to be replaced
  Var var_;
  // the lanes.
...
@@ -170,8 +172,9 @@
// We use ExprFunctor directly instead of StmtExprMutator
// This is because the transformation can change the dtype of the Expr
// The existing ExprMutator transformation rules may not be well defined.
class Vectorizer : public StmtMutator,
                   public ExprFunctor<PrimExpr(const PrimExpr &)> {
public:
  using ExprFunctor::VisitExpr;
  using StmtMutator::operator();
...
@@ -179,7 +182,7 @@ class Vectorizer
    ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
  }

  Stmt VisitStmt(const Stmt &stmt) final {
    ICHECK(!need_scalarize_);
    Stmt ret = StmtMutator::VisitStmt(stmt);
    if (need_scalarize_) {
...
@@ -190,17 +193,19 @@ class Vectorizer
    }
  }

  PrimExpr VisitExpr(const PrimExpr &e) final { return ExprFunctor::VisitExpr(e); }

  PrimExpr VisitExpr_(const AddNode *op) final {
    return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
  }
  PrimExpr VisitExpr_(const SubNode *op) final {
    return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
  }
  PrimExpr VisitExpr_(const MulNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
...
@@ -211,11 +216,12 @@ class Vectorizer
    if (is_vec_a && is_vec_b) {
      // Let's not multiply scalable and fixed length vectors
      ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
          << "Fixed length and scalable vectors can't be mixed in "
             "multiplication.";
    }
    if (is_vec_a || is_vec_b) {
      const RampNode *b_ramp = b.as<RampNode>();
      const RampNode *a_ramp = a.as<RampNode>();
      if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
        PrimExpr lanes = a_ramp->lanes;
        return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
...
@@ -227,28 +233,34 @@ class Vectorizer
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int max_lanes = std::max(a_lanes, b_lanes);
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return Mul(BroadcastTo(a, max_lanes, is_scalable),
                 BroadcastTo(b, max_lanes, is_scalable));
      }
    }
    return BinaryVec<Mul>(op);
  }
  PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
  PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
  PrimExpr VisitExpr_(const FloorDivNode *op) final {
    return BinaryVec<FloorDiv>(op);
  }
  PrimExpr VisitExpr_(const FloorModNode *op) final {
    return BinaryVec<FloorMod>(op);
  }
  PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
  PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
  PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
  PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
  PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
  PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
  PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
  PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
  PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
  PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }
  PrimExpr VisitExpr_(const NotNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    if (a.same_as(op->a)) {
      return GetRef<PrimExpr>(op);
...
@@ -257,7 +269,7 @@ class Vectorizer
    }
  }

  PrimExpr VisitExpr_(const RampNode *op) final {
    PrimExpr base = this->VisitExpr(op->base);
    PrimExpr stride = this->VisitExpr(op->stride);
    ICHECK(!base.dtype().is_scalable_vector())
...
@@ -267,11 +279,13 @@ class Vectorizer
    if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
      ICHECK(op->lanes->IsInstance<IntImmNode>())
          << "Vectorizing over existing scalable vectors is not supported.";
      const RampNode *base_ramp = base.as<RampNode>();
      int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
      int base_ramp_lanes =
          static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
      if (analyzer_.CanProve(base_ramp->stride ==
                             stride * make_const(stride.dtype(),
                                                 base_ramp_lanes))) {
        return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
      }
    }
...
@@ -280,13 +294,13 @@ class Vectorizer
    stride = BroadcastTo(stride, lanes, false);
    Array<PrimExpr> elems;
    for (int i = 0; i < lanes; ++i) {
      elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
                           Shuffle::ExtractElement(stride, i), op->lanes));
    }
    return Shuffle::Concat(elems);
  }

  PrimExpr VisitExpr_(const BroadcastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
...
@@ -299,45 +313,56 @@ class Vectorizer
    }
  }

  PrimExpr VisitExpr_(const SelectNode *op) final {
    PrimExpr cond = this->VisitExpr(op->condition);
    PrimExpr t = this->VisitExpr(op->true_value);
    PrimExpr f = this->VisitExpr(op->false_value);
    if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
        f.same_as(op->false_value)) {
      return GetRef<PrimExpr>(op);
    } else {
      int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
      bool is_scalable = cond.dtype().is_scalable_vector() ||
                         t.dtype().is_scalable_vector() ||
                         f.dtype().is_scalable_vector();
      return Select(BroadcastTo(cond, lanes, is_scalable),
                    BroadcastTo(t, lanes, is_scalable),
                    BroadcastTo(f, lanes, is_scalable));
    }
  }
  PrimExpr VisitExpr_(const CastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.same_as(op->value)) {
      return GetRef<PrimExpr>(op);
    } else {
      if (value.dtype().is_scalable_vector()) {
        return Cast(op->dtype.with_scalable_vscale_factor(
                        value.dtype().vscale_factor()),
                    value);
      } else {
        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
      }
    }
  }
  PrimExpr VisitExpr_(const FloatImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  PrimExpr VisitExpr_(const IntImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  PrimExpr VisitExpr_(const StringImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  // Variable
  PrimExpr VisitExpr_(const VarNode *op) final {
    Var var = GetRef<Var>(op);
    if (var.same_as(var_)) {
...
@@ -351,7 +376,7 @@ class Vectorizer
    }
  }
  // IfThenElse expr
  PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
    PrimExpr cond = this->VisitExpr(op->args[0]);
    if (cond.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
...
@@ -359,24 +384,27 @@ class Vectorizer
    }
    PrimExpr t = this->VisitExpr(op->args[1]);
    PrimExpr f = this->VisitExpr(op->args[2]);
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
t
.
same_as
(
op
->
args
[
1
])
&&
f
.
same_as
(
op
->
args
[
2
]))
{
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
t
.
same_as
(
op
->
args
[
1
])
&&
f
.
same_as
(
op
->
args
[
2
]))
{
return
GetRef
<
PrimExpr
>
(
op
);
}
else
{
int
t_lanes
=
t
.
dtype
().
get_lanes_or_vscale_factor
();
int
f_lanes
=
f
.
dtype
().
get_lanes_or_vscale_factor
();
int
lanes
=
std
::
max
(
t_lanes
,
f_lanes
);
bool
is_scalable
=
t
.
dtype
().
is_scalable_vector
()
||
f
.
dtype
().
is_scalable_vector
();
bool
is_scalable
=
t
.
dtype
().
is_scalable_vector
()
||
f
.
dtype
().
is_scalable_vector
();
t
=
BroadcastTo
(
t
,
lanes
,
is_scalable
);
f
=
BroadcastTo
(
f
,
lanes
,
is_scalable
);
if
(
is_scalable
)
{
return
Call
(
op
->
dtype
.
with_scalable_vscale_factor
(
lanes
),
op
->
op
,
{
cond
,
t
,
f
});
return
Call
(
op
->
dtype
.
with_scalable_vscale_factor
(
lanes
),
op
->
op
,
{
cond
,
t
,
f
});
}
else
{
return
Call
(
op
->
dtype
.
with_lanes
(
lanes
),
op
->
op
,
{
cond
,
t
,
f
});
}
}
}
// Reinterpret expr
PrimExpr
MutateReinterpretExpr_
(
const
CallNode
*
op
)
{
PrimExpr
MutateReinterpretExpr_
(
const
CallNode
*
op
)
{
ICHECK
(
op
->
op
.
same_as
(
builtin
::
reinterpret
()));
PrimExpr
value
=
this
->
VisitExpr
(
op
->
args
[
0
]);
if
(
value
.
same_as
(
op
->
args
[
0
]))
{
...
...
@@ -384,14 +412,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
else
{
int
lanes
=
value
.
dtype
().
get_lanes_or_vscale_factor
();
if
(
value
.
dtype
().
is_scalable_vector
())
{
return
Call
(
op
->
dtype
.
with_scalable_vscale_factor
(
lanes
),
op
->
op
,
{
value
});
return
Call
(
op
->
dtype
.
with_scalable_vscale_factor
(
lanes
),
op
->
op
,
{
value
});
}
else
{
return
Call
(
op
->
dtype
.
with_lanes
(
lanes
),
op
->
op
,
{
value
});
}
}
}
// Call
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
return
MutateIfThenElseExpr_
(
op
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
texture2d_load
()))
{
...
...
@@ -406,13 +435,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
// Vectorize the value to store
Array
<
PrimExpr
>
value
{
op
->
args
.
back
()};
Array
<
PrimExpr
>
mutated_value
=
MutateArray
(
value
,
&
lane
);
Array
<
PrimExpr
>
new_args
{
op
->
args
[
0
],
op
->
args
[
1
],
op
->
args
[
2
],
mutated_value
[
0
]};
Array
<
PrimExpr
>
new_args
{
op
->
args
[
0
],
op
->
args
[
1
],
op
->
args
[
2
],
mutated_value
[
0
]};
return
Call
(
op
->
dtype
.
with_lanes
(
lane
),
op
->
op
,
new_args
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
reinterpret
()))
{
return
MutateReinterpretExpr_
(
op
);
}
auto
optional_op
=
op
->
op
.
as
<
Op
>
();
bool
vectorizable
=
optional_op
&&
op_vectorizable_
.
get
(
optional_op
.
value
(),
false
)
&&
bool
vectorizable
=
optional_op
&&
op_vectorizable_
.
get
(
optional_op
.
value
(),
false
)
&&
!
op
->
dtype
.
is_scalable_vector
();
if
(
!
vectorizable
)
{
...
...
@@ -443,10 +474,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// BufferLoad
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
load
=
GetRef
<
BufferLoad
>
(
op
);
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
return
this
->
VisitExpr
(
index
);
};
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
return
this
->
VisitExpr
(
index
);
};
Array
<
PrimExpr
>
indices
=
op
->
indices
.
Map
(
fmutate
);
if
(
!
indices
.
same_as
(
op
->
indices
))
{
...
...
@@ -457,7 +490,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return
std
::
move
(
load
);
}
// Let
PrimExpr
VisitExpr_
(
const
LetNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
LetNode
*
op
)
final
{
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
// Weaker SSA condition
// A single var can be binded in multiple lets
...
...
@@ -486,24 +519,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// BufferStore
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
auto
store
=
GetRef
<
BufferStore
>
(
op
);
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
return
this
->
VisitExpr
(
index
);
};
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
return
this
->
VisitExpr
(
index
);
};
Array
<
PrimExpr
>
indices
=
op
->
indices
.
Map
(
fmutate
);
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
if
(
!
indices
.
same_as
(
op
->
indices
)
||
!
value
.
same_as
(
op
->
value
))
{
ICHECK
(
!
op
->
buffer
->
dtype
.
is_scalable_vector
())
<<
"Vectorizing over scalable buffer elements is not supported in vectorizer."
;
<<
"Vectorizing over scalable buffer elements is not supported in "
"vectorizer."
;
// How many lanes of indexing are present in the index and
// buffer element type, excluding the last index.
int
other_index_lanes
=
op
->
buffer
->
dtype
.
lanes
();
for
(
size_t
i
=
0
;
i
<
indices
.
size
()
-
1
;
i
++
)
{
other_index_lanes
*=
indices
[
i
].
dtype
().
lanes
();
// Only allow the last index to be scalable
ICHECK
(
!
indices
[
i
].
dtype
().
is_scalable_vector
())
<<
"Only the last index can be scalable."
;
ICHECK
(
!
indices
[
i
].
dtype
().
is_scalable_vector
())
<<
"Only the last index can be scalable."
;
}
// The total number of lanes of indexing, including the last index.
...
...
@@ -519,14 +556,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int
total_lanes
=
std
::
max
(
index_lanes
,
value_dtype_lanes
);
ICHECK_EQ
(
total_lanes
%
other_index_lanes
,
0
)
<<
"When storing to buffer "
<<
op
->
buffer
->
name
<<
", cannot produce "
<<
total_lanes
<<
"When storing to buffer "
<<
op
->
buffer
->
name
<<
", cannot produce "
<<
total_lanes
<<
" lanes of storage location by changing the last index."
;
int
last_index_lanes
=
total_lanes
/
other_index_lanes
;
// Broadcast the last index such that the total number of index
// lanes matches the desired number.
indices
.
Set
(
indices
.
size
()
-
1
,
BroadcastTo
(
indices
[
indices
.
size
()
-
1
],
last_index_lanes
,
is_last_index_scalable
));
indices
.
Set
(
indices
.
size
()
-
1
,
BroadcastTo
(
indices
[
indices
.
size
()
-
1
],
last_index_lanes
,
is_last_index_scalable
));
auto
writer
=
store
.
CopyOnWrite
();
writer
->
indices
=
indices
;
...
...
@@ -536,7 +575,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return
std
::
move
(
store
);
}
// For
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
if
(
op
->
kind
==
ForKind
::
kVectorized
)
{
LOG
(
WARNING
)
<<
"Detect vectorize inside vectorized loop, ignoring..."
;
}
...
...
@@ -550,12 +589,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if
(
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Stmt
>
(
op
);
}
else
{
return
For
(
op
->
loop_var
,
op
->
min
,
extent
,
op
->
kind
,
body
,
op
->
thread_binding
,
op
->
annotations
);
return
For
(
op
->
loop_var
,
op
->
min
,
extent
,
op
->
kind
,
body
,
op
->
thread_binding
,
op
->
annotations
);
}
}
// IfThenElse
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
ICHECK
(
!
op
->
condition
.
dtype
().
is_scalable_or_fixed_length_vector
());
PrimExpr
condition
=
this
->
VisitExpr
(
op
->
condition
);
if
(
condition
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
...
...
@@ -574,13 +613,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// While
Stmt
VisitStmt_
(
const
WhileNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
WhileNode
*
op
)
final
{
LOG
(
FATAL
)
<<
"A while loop inside a vectorized loop not supported."
;
}
// LetStmt
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
ICHECK
(
!
let_binding_
.
count
(
op
->
var
))
<<
"SSA violation, a single var is binded twice"
;
ICHECK
(
!
let_binding_
.
count
(
op
->
var
))
<<
"SSA violation, a single var is binded twice"
;
let_binding_
[
op
->
var
]
=
value
;
if
(
value
.
dtype
().
get_lanes_or_vscale_factor
()
!=
...
...
@@ -599,20 +639,22 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
// Allocate
Stmt
VisitStmt_
(
const
AllocateNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
AllocateNode
*
op
)
final
{
// Mutate the condition
PrimExpr
condition
=
this
->
VisitExpr
(
op
->
condition
);
if
(
condition
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
<<
op
->
buffer_var
->
name_hint
;
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
<<
op
->
buffer_var
->
name_hint
;
return
Scalarize
(
GetRef
<
Stmt
>
(
op
));
}
// Mutate the extents
Array
<
PrimExpr
>
extents
;
for
(
const
auto
&
extent
:
op
->
extents
)
{
for
(
const
auto
&
extent
:
op
->
extents
)
{
PrimExpr
new_ext
=
this
->
VisitExpr
(
extent
);
if
(
new_ext
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
<<
op
->
buffer_var
->
name_hint
;
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
<<
op
->
buffer_var
->
name_hint
;
return
Scalarize
(
GetRef
<
Stmt
>
(
op
));
}
extents
.
push_back
(
new_ext
);
...
...
@@ -629,7 +671,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
extents
.
Set
(
extents
.
size
()
-
1
,
extents
[
extents
.
size
()
-
1
]
*
var_lanes_
);
// Rewrite access to the buffer in the body.
Stmt
body
=
VecAllocAccess
(
op
->
buffer_var
.
get
(),
var_
,
var_lanes_
)(
op
->
body
);
Stmt
body
=
VecAllocAccess
(
op
->
buffer_var
.
get
(),
var_
,
var_lanes_
)(
op
->
body
);
body
=
this
->
VisitStmt
(
body
);
return
Allocate
(
op
->
buffer_var
,
op
->
dtype
,
extents
,
condition
,
body
);
}
...
...
@@ -641,11 +684,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return
For
(
idx
,
IntImm
(
var_
->
dtype
,
0
),
var_lanes_
,
ForKind
::
kSerial
,
stmt
);
}
// ProducerStore
Stmt
VisitStmt_
(
const
ProducerStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ProducerStoreNode
*
op
)
final
{
LOG
(
FATAL
)
<<
"ProducerProvide cannot appear in a TIR PrimFunc"
;
}
private:
private:
// analyzer
arith
::
Analyzer
analyzer_
;
// deep equal
...
...
@@ -661,19 +704,22 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
// Let binding
std
::
unordered_map
<
Var
,
PrimExpr
,
ObjectPtrHash
,
ObjectPtrEqual
>
let_binding_
;
// vectorizable property
OpAttrMap
<
TVectorizable
>
op_vectorizable_
=
Op
::
GetAttrMap
<
TVectorizable
>
(
"TVectorizable"
);
OpAttrMap
<
TVectorizable
>
op_vectorizable_
=
Op
::
GetAttrMap
<
TVectorizable
>
(
"TVectorizable"
);
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array
<
PrimExpr
>
MutateArray
(
Array
<
PrimExpr
>
arr
,
int
*
p_lanes
)
{
if
(
arr
.
size
()
==
0
)
return
arr
;
int
&
lanes
=
*
p_lanes
;
Array
<
PrimExpr
>
MutateArray
(
Array
<
PrimExpr
>
arr
,
int
*
p_lanes
)
{
if
(
arr
.
size
()
==
0
)
return
arr
;
int
&
lanes
=
*
p_lanes
;
bool
changed
=
false
;
std
::
vector
<
PrimExpr
>
new_arr
(
arr
.
size
());
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
PrimExpr
old_elem
=
arr
[
i
];
PrimExpr
new_elem
=
this
->
VisitExpr
(
old_elem
);
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
new_arr
[
i
]
=
new_elem
;
lanes
=
std
::
max
(
lanes
,
new_elem
.
dtype
().
lanes
());
}
...
...
@@ -684,12 +730,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
changed
=
true
;
}
}
if
(
!
changed
)
return
arr
;
if
(
!
changed
)
return
arr
;
return
Array
<
PrimExpr
>
(
new_arr
);
}
template
<
typename
TOp
,
typename
T
>
PrimExpr
BinaryVec
(
const
T
*
op
)
{
static_assert
(
std
::
is_same
<
typename
TOp
::
ContainerType
,
T
>::
value
,
"constraint"
);
template
<
typename
TOp
,
typename
T
>
PrimExpr
BinaryVec
(
const
T
*
op
)
{
static_assert
(
std
::
is_same
<
typename
TOp
::
ContainerType
,
T
>::
value
,
"constraint"
);
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
...
...
@@ -698,12 +745,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int
a_lanes
=
a
.
dtype
().
get_lanes_or_vscale_factor
();
int
b_lanes
=
b
.
dtype
().
get_lanes_or_vscale_factor
();
int
lanes
=
std
::
max
(
a_lanes
,
b_lanes
);
bool
is_scalable
=
a
.
dtype
().
is_scalable_vector
()
||
b
.
dtype
().
is_scalable_vector
();
return
TOp
(
BroadcastTo
(
a
,
lanes
,
is_scalable
),
BroadcastTo
(
b
,
lanes
,
is_scalable
));
bool
is_scalable
=
a
.
dtype
().
is_scalable_vector
()
||
b
.
dtype
().
is_scalable_vector
();
return
TOp
(
BroadcastTo
(
a
,
lanes
,
is_scalable
),
BroadcastTo
(
b
,
lanes
,
is_scalable
));
}
}
template
<
typename
T
,
typename
FCompute
>
PrimExpr
AddSubVec
(
const
T
*
op
,
FCompute
fcompute
)
{
PrimExpr
AddSubVec
(
const
T
*
op
,
FCompute
fcompute
)
{
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
...
...
@@ -713,21 +762,25 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
int
b_lanes
=
b
.
dtype
().
get_lanes_or_vscale_factor
();
int
lanes
=
std
::
max
(
a_lanes
,
b_lanes
);
if
(
lanes
!=
1
)
{
const
RampNode
*
b_ramp
=
b
.
as
<
RampNode
>
();
const
RampNode
*
a_ramp
=
a
.
as
<
RampNode
>
();
const
RampNode
*
b_ramp
=
b
.
as
<
RampNode
>
();
const
RampNode
*
a_ramp
=
a
.
as
<
RampNode
>
();
if
(
a
.
dtype
().
is_scalar
()
&&
b_ramp
)
{
return
Ramp
(
fcompute
(
a
,
b_ramp
->
base
),
fcompute
(
make_zero
(
b_ramp
->
stride
.
dtype
()),
b_ramp
->
stride
),
b_ramp
->
lanes
);
return
Ramp
(
fcompute
(
a
,
b_ramp
->
base
),
fcompute
(
make_zero
(
b_ramp
->
stride
.
dtype
()),
b_ramp
->
stride
),
b_ramp
->
lanes
);
}
if
(
b
.
dtype
().
is_scalar
()
&&
a_ramp
)
{
return
Ramp
(
fcompute
(
a_ramp
->
base
,
b
),
a_ramp
->
stride
,
a_ramp
->
lanes
);
}
}
bool
is_scalable
=
a
.
dtype
().
is_scalable_vector
()
||
b
.
dtype
().
is_scalable_vector
();
return
fcompute
(
BroadcastTo
(
a
,
lanes
,
is_scalable
),
BroadcastTo
(
b
,
lanes
,
is_scalable
));
bool
is_scalable
=
a
.
dtype
().
is_scalable_vector
()
||
b
.
dtype
().
is_scalable_vector
();
return
fcompute
(
BroadcastTo
(
a
,
lanes
,
is_scalable
),
BroadcastTo
(
b
,
lanes
,
is_scalable
));
}
}
};
}
// namespace tl
}
// namespace tvm
\ No newline at end of file
}
// namespace tl
}
// namespace tvm
\ No newline at end of file
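To make the lane-matching rule in BinaryVec/AddSubVec concrete: both operands are broadcast to the widest lane count before the scalar operator is reapplied lane-wise. The following standalone C++ sketch (not part of this commit; names such as BroadcastTo and AddVec are illustrative stand-ins, not the TVM/TileLang API) models that rule with plain integer vectors.

// Minimal sketch of the lane-broadcast rule, assuming a 1-lane operand may be
// broadcast and wider operands must already agree on lane count.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Broadcast a 1-lane value to `lanes` copies; leave wider vectors untouched.
static std::vector<int> BroadcastTo(const std::vector<int> &v, int lanes) {
  if (static_cast<int>(v.size()) == lanes) return v;
  assert(v.size() == 1 && "only scalars may be broadcast");
  return std::vector<int>(lanes, v[0]);
}

// Lane-wise addition after matching lane counts, mirroring BinaryVec<Add>.
static std::vector<int> AddVec(const std::vector<int> &a, const std::vector<int> &b) {
  int lanes = std::max(static_cast<int>(a.size()), static_cast<int>(b.size()));
  std::vector<int> lhs = BroadcastTo(a, lanes), rhs = BroadcastTo(b, lanes), out(lanes);
  for (int i = 0; i < lanes; ++i) out[i] = lhs[i] + rhs[i];
  return out;
}

int main() {
  std::vector<int> ramp = {0, 1, 2, 3};  // a 4-lane vector
  std::vector<int> scalar = {10};        // a 1-lane value
  for (int v : AddVec(ramp, scalar)) std::printf("%d ", v);  // prints 10 11 12 13
  std::printf("\n");
  return 0;
}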
src/transform/frontend_legalize.cc
...
...
@@ -34,19 +34,19 @@ namespace tl {
using namespace tir;

class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
 public:
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    FrontendLegalizer substituter(&analyzer);
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

 private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  Stmt VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kParallel) {
      parallel_for_scope_++;
    }
...
...
@@ -57,7 +57,7 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
    return n;
  }

  PrimExpr VisitExpr_(const VarNode *node) final {
    if (let_bindings_.count(node)) {
      return arith::IRMutatorWithAnalyzer::VisitExpr(let_bindings_[node]);
    } else {
...
...
@@ -65,18 +65,18 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
    }
  }

  Stmt VisitStmt_(const LetStmtNode *node) final {
    let_bindings_[node->var.get()] = node->value;
    return arith::IRMutatorWithAnalyzer::VisitStmt(node->body);
  }

  PrimExpr VisitExpr_(const LetNode *node) final {
    let_bindings_[node->var.get()] = node->value;
    return arith::IRMutatorWithAnalyzer::VisitExpr(node->body);
  }

  int parallel_for_scope_ = 0;
  std::unordered_map<const VarNode *, PrimExpr> let_bindings_;
};

using namespace tir::transform;
...
...
@@ -91,5 +91,5 @@ Pass FrontendLegalize() {
TVM_REGISTER_GLOBAL("tl.transform.FrontendLegalize").set_body_typed(FrontendLegalize);

}  // namespace tl
}  // namespace tvm
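The legalizer above records every `let v = e` binding and replaces later uses of `v` with a revisited copy of `e`. A standalone sketch of that inlining idea follows (not part of this commit; the simplified `Expr` type and `InlineLets` helper are illustrative stand-ins for TIR's PrimExpr and the mutator).

// Minimal sketch of map-based let inlining, assuming expressions are either a
// variable reference or a constant.
#include <cstdio>
#include <string>
#include <unordered_map>

struct Expr {        // either a variable reference or a constant
  std::string var;   // non-empty when this node is a variable
  int value = 0;     // used when this node is a constant
};

// Replace variable references with their recorded binding, if any.
static Expr InlineLets(const Expr &e, const std::unordered_map<std::string, Expr> &bindings) {
  if (!e.var.empty()) {
    auto it = bindings.find(e.var);
    if (it != bindings.end()) return InlineLets(it->second, bindings);  // chase chains of lets
  }
  return e;
}

int main() {
  std::unordered_map<std::string, Expr> bindings;
  bindings["x"] = Expr{"", 42};  // let x = 42
  bindings["y"] = Expr{"x", 0};  // let y = x
  Expr use{"y", 0};
  std::printf("y resolves to %d\n", InlineLets(use, bindings).value);  // prints 42
  return 0;
}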
src/transform/inject_fence_proxy.cc
...
...
@@ -38,10 +38,10 @@ using namespace tir;

enum class Proxy { kGeneric, kAsync, kBoth };

class ProxyMarker : public StmtVisitor {
 public:
  ProxyMarker() = default;

  Proxy GetProxy(const StmtNode *stmt) const {
    auto it = map_.find(stmt);
    // ICHECK(it != map_.end());
    // TODO: This is a hack implementation to avoid the ICHECK failure.
...
...
@@ -51,9 +51,9 @@ class ProxyMarker : public StmtVisitor {
    return it->second;
  }

  Proxy GetProxy(const Stmt &stmt) const { return GetProxy(stmt.get()); }

  void VisitStmt_(const EvaluateNode *op) final {
    Proxy proxy = Proxy::kAsync;
    if (auto call = op->value.as<CallNode>()) {
      if (call->op.same_as(LDMatrixOp()) || call->op.same_as(STMatrixOp())) {
...
...
@@ -63,12 +63,12 @@ class ProxyMarker : public StmtVisitor {
    SetProxy(op, proxy);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    Proxy proxy = Proxy::kGeneric;
    SetProxy(op, proxy);
  }

  void VisitStmt_(const SeqStmtNode *op) final {
    StmtVisitor::VisitStmt_(op);
    auto role = GetProxy(op->seq[0]);
    for (auto stmt : op->seq) {
...
...
@@ -80,61 +80,59 @@ class ProxyMarker : public StmtVisitor {
    SetProxy(op, role);
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    StmtVisitor::VisitStmt_(op);
    auto role = GetProxy(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetProxy(op->else_case.value());
      if (role != role_else) role = Proxy::kBoth;
    }
    SetProxy(op, role);
  }

  void VisitStmt_(const BlockRealizeNode *op) final {
    StmtVisitor::VisitStmt_(op);
    SetProxy(op, GetProxy(op->block));
  }

  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
    StmtVisitor::VisitStmt_(op);
    SetProxy(op, GetProxy(op->body));
  }

  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }

 private:
  void SetProxy(const StmtNode *stmt, Proxy proxy) { map_[stmt] = proxy; }
  std::unordered_map<const StmtNode *, Proxy> map_;
};

class InjectFenceProxy : public StmtExprMutator {
 public:
  static PrimFunc Substitute(PrimFunc f) {
    auto T = InjectFenceProxy();
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

 private:
  Proxy get_generic_proxy(const Stmt &stmt) {
    auto marker = ProxyMarker();
    marker(stmt);
    return marker.GetProxy(stmt);
  }

  Stmt VisitStmt_(const SeqStmtNode *op) final {
    ICHECK(op->seq.size() > 0);
    Array<Stmt> new_body;
    Proxy cur_proxy, prev_proxy;
    auto fence_stmt = Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {}));
    prev_proxy = get_generic_proxy(op->seq[0]);
    new_body.push_back(VisitStmt(op->seq[0]));
    if (op->seq.size() > 1) {
...
...
@@ -171,5 +169,5 @@ tvm::transform::Pass InjectFenceProxy() {
TVM_REGISTER_GLOBAL("tl.transform.InjectFenceProxy").set_body_typed(InjectFenceProxy);

}  // namespace tl
}  // namespace tvm
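The pass above classifies each statement as running on the generic or the async proxy and inserts a fence statement where the sequence switches between them; the exact switching condition is in the part of the diff that is collapsed here. A standalone sketch of that scheduling idea follows (not part of this commit; the Stmt/Proxy types and the generic-to-async condition are assumptions for illustration only).

// Minimal sketch: walk a statement sequence and emit a fence at every
// generic -> async transition, roughly mirroring the SeqStmt handling above.
#include <cstdio>
#include <string>
#include <vector>

enum class Proxy { kGeneric, kAsync };

struct Stmt {
  std::string name;
  Proxy proxy;
};

static std::vector<Stmt> InjectFences(const std::vector<Stmt> &seq) {
  std::vector<Stmt> out;
  for (size_t i = 0; i < seq.size(); ++i) {
    if (i > 0 && seq[i - 1].proxy == Proxy::kGeneric && seq[i].proxy == Proxy::kAsync) {
      out.push_back({"fence.proxy.async", Proxy::kGeneric});  // separate the two proxies
    }
    out.push_back(seq[i]);
  }
  return out;
}

int main() {
  std::vector<Stmt> seq = {{"init_smem", Proxy::kGeneric}, {"tma_load", Proxy::kAsync}};
  for (const Stmt &s : InjectFences(seq)) std::printf("%s\n", s.name.c_str());
  return 0;
}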
src/transform/inject_pipeline.cc
...
...
@@ -19,7 +19,8 @@
/*!
 * \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined one that parallelize
 * producers and consumers
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
...
...
@@ -38,24 +39,27 @@ using namespace tir;

/*!
 * \brief Create a block and infer the access region with the given body.
 *
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
Block MakeBlock(const Stmt &body, const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body);
  Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
  n->reads = access[0];
  n->writes = access[1];
  return block;
...
...
@@ -68,69 +72,76 @@ struct PipelineAnnotation {
  bool async;
};

using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>;

struct BufferAccessInfo {
  int def = -1;  // the defining stage of the buffer
  int use = -1;  // the last using stage of the buffer
};

/*!
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
 */
class PipelineBodyRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param access_all_versions Whether all versions the buffers in the software
   * pipeline are accessed. This will be used to update block access region. In
   * the prologue and epilogue of a two-stage software pipeline, only one
   * version of these buffers are accessed.
   */
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap, For pipeline_loop,
                       bool access_all_versions)
      : buffer_data_to_buffer_(buffer_data_to_buffer), buffer_remap_(buffer_remap),
        pipeline_loop_(pipeline_loop), access_all_versions_(access_all_versions) {}

 private:
  BufferRegion RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

  PrimExpr RewriteBufferAccess(const Call &call, const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                   make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
...
...
@@ -138,62 +149,63 @@ class PipelineBodyRewriter : public StmtExprMutator {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version =
        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version =
        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
...
...
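To illustrate the index rewriting performed by PipelineBodyRewriter: a multi-versioned buffer gains a leading dimension, and every access at loop iteration `i` selects version `(i - loop_min) % num_versions`. The following standalone C++ sketch (not part of this commit; plain integers stand in for PrimExpr) shows the arithmetic.

// Minimal sketch of pipeline version selection, assuming non-negative loop
// iterators so % matches floormod.
#include <cstdio>

struct Access {
  int version;  // the prepended leading index
  int index;    // the original (flattened) index, unchanged
};

static Access RewriteAccess(int loop_iter, int loop_min, int num_versions, int original_index) {
  Access a;
  a.version = (loop_iter - loop_min) % num_versions;  // which copy of the buffer to touch
  a.index = original_index;
  return a;
}

int main() {
  // A double-buffered (num_versions == 2) tile accessed at element 7.
  for (int i = 0; i < 4; ++i) {
    Access a = RewriteAccess(/*loop_iter=*/i, /*loop_min=*/0, /*num_versions=*/2, /*index=*/7);
    std::printf("iter %d -> buf[%d][%d]\n", i, a.version, a.index);  // versions alternate 0,1,0,1
  }
  return 0;
}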
@@ -208,24 +220,25 @@ class PipelineBodyRewriter : public StmtExprMutator {
};

/*!
 * \brief Rewriter for the software pipeline that rewrite a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
 public:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs, const For &pipeline_loop,
                   const PipelineInfo &pipeline_info)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info) {}

  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos =
        GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
...
...
@@ -233,27 +246,28 @@ class PipelineRewriter : public StmtExprMutator {
    }
    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
    }
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
...
...
@@ -263,13 +277,15 @@ class PipelineRewriter : public StmtExprMutator {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] = state.commit_groups.size() - 1;
        }
      }
      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage && producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
...
...
@@ -277,17 +293,21 @@ class PipelineRewriter : public StmtExprMutator {
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true, true);
    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
                         pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue =
        EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
                 pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
...
...
@@ -296,26 +316,28 @@ class PipelineRewriter : public StmtExprMutator {
    return BlockRealize({}, Bool(true), block);
  }

 private:
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

      for (const BufferRegion &write : block->writes) {
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(write->buffer);
        if (info.def == -1) {
          info.def = stage;
        } else {
...
...
@@ -323,11 +345,11 @@ class PipelineRewriter : public StmtExprMutator {
        }
      }
      for (const BufferRegion &read : block->reads) {
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(read->buffer);
        info.use = std::max(info.use, stage);
      }
    }
...
...
@@ -355,58 +377,64 @@ class PipelineRewriter : public StmtExprMutator {
  }

  /*!
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
  int ComputeBufferVersions(const Buffer &buffer, const BufferAccessInfo &buffer_info) {
    if (buffer_info.def == -1) {
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
      bool need_multi_version = false;
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;

        auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(),
                                  [&](const BufferRegion &buffer_region) {
                                    return buffer_region->buffer.same_as(buffer);
                                  });
          if (it2 == reader_block->reads.end()) {
            continue;
          }
          if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage &&
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
...
...
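The version-count bound used above is worth a small worked example: a buffer written at stage `def` and last read at stage `use` needs at most `use - def + 1` live copies, and buffers defined outside the pipeline keep a single version. The standalone C++ sketch below (not part of this commit) shows only this upper bound; the overlap-based refinement in the real pass is omitted.

// Minimal sketch of the per-buffer version upper bound, assuming stages are
// small non-negative integers and -1 marks a buffer defined outside the loop.
#include <cstdio>

static int ComputeVersionsUpperBound(int def_stage, int use_stage) {
  if (def_stage == -1) return 1;     // defined outside the software pipeline
  return use_stage - def_stage + 1;  // liveness span across pipeline stages
}

int main() {
  // A tile produced by an async copy at stage 0 and consumed at stage 1
  // needs two versions, i.e. classic double buffering.
  std::printf("%d\n", ComputeVersionsUpperBound(0, 1));   // prints 2
  std::printf("%d\n", ComputeVersionsUpperBound(-1, 1));  // prints 1
  return 0;
}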
@@ -421,13 +449,12 @@ class PipelineRewriter : public StmtExprMutator {
}
/*!
* \brief Rewrite buffer allocation to keep multiple versions of original buffer for pipelined
* accesses.
* \param buffer The buffer to be resized.
* \brief Rewrite buffer allocation to keep multiple versions of original
* buffer for pipelined accesses. \param buffer The buffer to be resized.
* \param num_versions The number of versions to keep.
* \return The resized buffer.
*/
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
ObjectPtr
<
BufferNode
>
new_buffer
=
make_object
<
BufferNode
>
(
*
(
buffer
.
get
()));
new_buffer
->
shape
.
insert
(
new_buffer
->
shape
.
begin
(),
PrimExpr
(
num_versions
));
if
(
new_buffer
->
strides
.
size
())
{
...
...
@@ -438,29 +465,32 @@ class PipelineRewriter : public StmtExprMutator {
return
Buffer
(
new_buffer
);
}
// Per-stage states that need to be tracked across pipeline prologue, body, and epilogue.
// Per-stage states that need to be tracked across pipeline prologue, body,
// and epilogue.
struct
AsyncStateGlobal
{
// Buffers that this stage asynchronously writes.
std
::
unordered_set
<
const
BufferNode
*>
dst_buffers
;
// An imaginary index that the latest async operation associated with this stage has written
// into. Only valid if all associated predicates are true, so that we can count the number of
// async invocations exactly. When it is valid, it is the "sum of extents of loops that have
// been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This
// is only needed to compute wait count for epilogue without async producers.
std
::
unordered_set
<
const
BufferNode
*>
dst_buffers
;
// An imaginary index that the latest async operation associated with this
// stage has written into. Only valid if all associated predicates are true,
// so that we can count the number of async invocations exactly. When it is
// valid, it is the "sum of extents of loops that have been executed" - 1,
// e.g. for epilogue it is prologue extent + body extent - 1. This is only
// needed to compute wait count for epilogue without async producers.
PrimExpr
producer_head
;
std
::
vector
<
std
::
vector
<
int
>>
commit_groups
;
std
::
unordered_map
<
const
BufferNode
*
,
int
>
buffer_to_commit_group_
;
std
::
unordered_map
<
const
BufferNode
*
,
int
>
buffer_to_commit_group_
;
bool
writes
(
Buffer
buf
)
const
{
return
dst_buffers
.
count
(
buf
.
get
())
>
0
;
}
};
// Per-stage states that are local to each of pipeline prologue, body, and epilogue.
// Per-stage states that are local to each of pipeline prologue, body, and
// epilogue.
struct
AsyncStateLocal
{
struct
PendingWait
{
// The index into a list of blocks, where async_wait_queue should be
attached at the
// beginning.
// The index into a list of blocks, where async_wait_queue should be
//
attached at the
beginning.
int
insert_before
;
// in_flight_count would be a more precise name, but the implementation
uses wait_count for
// brevity.
// in_flight_count would be a more precise name, but the implementation
//
uses wait_count for
brevity.
PrimExpr
wait_count
{
nullptr
};
bool
valid
()
const
{
return
wait_count
.
defined
();
}
...
...
@@ -468,8 +498,8 @@ class PipelineRewriter : public StmtExprMutator {
std
::
vector
<
PendingWait
>
pending_waits
;
// A symbolic expression representing the index the latest async operation
associated with this
// stage has written into, at the "current" iteration.
// A symbolic expression representing the index the latest async operation
//
associated with this
stage has written into, at the "current" iteration.
Optional
<
PrimExpr
>
producer_head
;
};
...
...
@@ -483,31 +513,35 @@ class PipelineRewriter : public StmtExprMutator {
bool
is_async
;
};
void
PopulateWaitCounts
(
const
std
::
vector
<
RewrittenBlockInfo
>
&
new_blocks
,
std
::
map
<
int
,
AsyncStateLocal
>
*
async_states_local
)
{
void
PopulateWaitCounts
(
const
std
::
vector
<
RewrittenBlockInfo
>
&
new_blocks
,
std
::
map
<
int
,
AsyncStateLocal
>
*
async_states_local
)
{
for
(
size_t
i
=
0
;
i
<
new_blocks
.
size
();
++
i
)
{
int
producer_stage_idx
=
-
1
;
for
(
auto
read_region
:
new_blocks
[
i
].
block
->
reads
)
{
for
(
const
auto
&
[
stage
,
state
]
:
async_states
)
{
if
(
stage
<=
new_blocks
[
i
].
stage
&&
state
.
writes
(
read_region
->
buffer
))
{
// Found an earlier stage where read_region->buffer was asynchronously written
for
(
const
auto
&
[
stage
,
state
]
:
async_states
)
{
if
(
stage
<=
new_blocks
[
i
].
stage
&&
state
.
writes
(
read_region
->
buffer
))
{
// Found an earlier stage where read_region->buffer was
// asynchronously written
ICHECK
(
producer_stage_idx
==
-
1
||
producer_stage_idx
==
stage
)
<<
"A dependency on multiple async stages is not supported"
;
producer_stage_idx
=
stage
;
}
}
}
if
(
producer_stage_idx
==
-
1
)
continue
;
const
auto
&
state
=
async_states
[
producer_stage_idx
];
auto
&
dep_local_state
=
(
*
async_states_local
)[
producer_stage_idx
];
if
(
producer_stage_idx
==
-
1
)
continue
;
const
auto
&
state
=
async_states
[
producer_stage_idx
];
auto
&
dep_local_state
=
(
*
async_states_local
)[
producer_stage_idx
];
PrimExpr
in_flight_cnt
=
0
;
for
(
const
auto
&
group
:
state
.
commit_groups
)
{
for
(
const
auto
&
group
:
state
.
commit_groups
)
{
PrimExpr
consumer_head
=
new_blocks
[
i
].
access_index
;
PrimExpr
producer_head
;
if
(
dep_local_state
.
producer_head
.
defined
())
{
producer_head
=
dep_local_state
.
producer_head
.
value
();
// if the group is after the wait point, minus by 1
if
(
group
.
front
()
>
new_blocks
[
i
].
order
)
producer_head
-=
1
;
if
(
group
.
front
()
>
new_blocks
[
i
].
order
)
producer_head
-=
1
;
}
else
{
producer_head
=
state
.
producer_head
;
}
...
...
@@ -516,41 +550,43 @@ class PipelineRewriter : public StmtExprMutator {
// We can relax the in-flight-count by the number of independent commit.
std
::
unordered_set
<
int
>
dependent_groups
;
for
(
const
auto
&
read_region
:
new_blocks
[
i
].
block
->
reads
)
{
for
(
const
auto
&
read_region
:
new_blocks
[
i
].
block
->
reads
)
{
if
(
state
.
buffer_to_commit_group_
.
count
(
read_region
->
buffer
.
get
()))
dependent_groups
.
insert
(
state
.
buffer_to_commit_group_
.
at
(
read_region
->
buffer
.
get
()));
dependent_groups
.
insert
(
state
.
buffer_to_commit_group_
.
at
(
read_region
->
buffer
.
get
()));
}
for
(
int
i
=
int
(
state
.
commit_groups
.
size
())
-
1
;
i
>=
0
;
i
--
)
{
if
(
dependent_groups
.
count
(
i
)
==
0
)
in_flight_cnt
+=
1
;
else
break
;
// stop relaxing
break
;
// stop relaxing
}
in_flight_cnt
=
analyzer_
.
Simplify
(
in_flight_cnt
);
dep_local_state
.
pending_waits
.
push_back
({
static_cast
<
int
>
(
i
),
in_flight_cnt
});
dep_local_state
.
pending_waits
.
push_back
(
{
static_cast
<
int
>
(
i
),
in_flight_cnt
});
}
}
// Given pipelined blocks and async-related information, generate final loop
statements with async
// scopes (if any).
// Given pipelined blocks and async-related information, generate final loop
//
statements with async
scopes (if any).
Array
<
Stmt
>
CompletePipelineLoopStatements
(
const
std
::
vector
<
RewrittenBlockInfo
>
&
blocks
,
const
std
::
map
<
int
,
AsyncStateLocal
>
&
async_states_local
)
const
{
const
std
::
vector
<
RewrittenBlockInfo
>
&
blocks
,
const
std
::
map
<
int
,
AsyncStateLocal
>
&
async_states_local
)
const
{
std
::
vector
<
RewrittenBlockInfo
>
new_blocks
=
blocks
;
for
(
const
auto
&
[
stage_id
,
state
]
:
async_states_local
)
{
for
(
const
auto
&
pw
:
state
.
pending_waits
)
{
auto
&
block
=
new_blocks
[
pw
.
insert_before
].
block
;
BlockNode
*
n
=
block
.
CopyOnWrite
();
for
(
const
auto
&
[
stage_id
,
state
]
:
async_states_local
)
{
for
(
const
auto
&
pw
:
state
.
pending_waits
)
{
auto
&
block
=
new_blocks
[
pw
.
insert_before
].
block
;
BlockNode
*
n
=
block
.
CopyOnWrite
();
auto
zero
=
make_zero
(
DataType
::
Int
(
32
));
n
->
body
=
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_
queue_scope
,
stage_id
,
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_inflight_count
,
pw
.
wait_count
,
n
->
body
));
n
->
body
=
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_queue_scope
,
stage_id
,
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_
inflight_count
,
pw
.
wait_count
,
n
->
body
));
}
}
// mark the last async stmt as commit
std
::
unordered_set
<
int
>
commit_group_indices
;
for
(
const
auto
&
[
stage_id
,
state
]
:
async_states
)
{
for
(
const
auto
&
[
stage_id
,
state
]
:
async_states
)
{
for
(
size_t
i
=
0
;
i
<
state
.
commit_groups
.
size
();
++
i
)
{
commit_group_indices
.
insert
(
state
.
commit_groups
[
i
].
back
());
}
...
...
@@ -561,9 +597,9 @@ class PipelineRewriter : public StmtExprMutator {
for
(
size_t
i
=
0
;
i
<
new_blocks
.
size
();
i
++
)
{
Block
block
=
new_blocks
[
i
].
block
;
if
(
commit_group_indices
.
count
(
new_blocks
[
i
].
order
))
{
auto
commit_queue_scope
=
AttrStmt
(
make_zero
(
DataType
::
Int
(
32
)),
tir
::
attr
::
async_commit_queue_scope
,
new_blocks
[
i
].
stage
,
block
->
body
);
auto
commit_queue_scope
=
AttrStmt
(
make_zero
(
DataType
::
Int
(
32
)),
tir
::
attr
::
async_commit_queue_scope
,
new_blocks
[
i
].
stage
,
block
->
body
);
block
=
MakeBlock
(
commit_queue_scope
,
buffer_data_to_buffer_
);
}
stmts
.
push_back
(
BlockRealize
({},
new_blocks
[
i
].
predicate
,
block
));
...
...
@@ -579,15 +615,18 @@ class PipelineRewriter : public StmtExprMutator {
* \param unroll_loop Whether the loop should be unrolled.
* \return The result loop.
*/
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
              bool need_bound_check) {
  PrimExpr new_loop_var;
  PrimExpr extent = end - start;
  auto make_nop = []() {
    return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
  };
  bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
  if (is_unit_loop) {
    new_loop_var = start; // use constants as the loop var for unit loops
  } else {
    new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
    analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
...
...
@@ -598,45 +637,52 @@ class PipelineRewriter : public StmtExprMutator {
  // Async related
  std::map<int, AsyncStateLocal> async_states_local;
  for (const Block &block : ordered_stmts_) {
    int stage = pipeline_info_.at(block).stage;
    int order = pipeline_info_.at(block).order;
    PrimExpr inbound = Bool(true);
    PrimExpr skewed_loop_var = new_loop_var - stage;
    if (need_bound_check)
      inbound =
          analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
          (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
    if (analyzer_.CanProve(!inbound)) {
      continue;
    }
    Block new_block = Downcast<Block>(PipelineBodyRewriter(
        buffer_data_to_buffer_, buffer_remap_, pipeline_loop_,
        max_stage_ != 1)(block));
    PrimExpr delta = start - pipeline_loop_->min;
    // This variable corresponds to
    // - "producer_head" if this stage is an async producer
    // - "consumer_head" if this stage reads from asynchronously written
    //   buffers.
    PrimExpr normalized_access_index =
        is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
    // Adjust the block predicate and the body according to the final loop
    // bound [pipeline_loop_->min, extent).
    if (!is_unit_loop) {
      Var loop_iter = Downcast<Var>(new_loop_var);
      inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
    }
    new_block = Downcast<Block>(Substitute(
        new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
    if (pipeline_info_[block].async) {
      auto &local_state = async_states_local[stage];
      local_state.producer_head = normalized_access_index;
      BlockNode *n = new_block.CopyOnWrite();
      n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                         1, n->body);
    }
    new_blocks.push_back({stage, order, inbound, new_block,
                          normalized_access_index,
                          pipeline_info_[block].async});
  }
  PopulateWaitCounts(new_blocks, &async_states_local);
...
...
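A brief orientation note on the skewing used above (this reading is the editor's, not text from the commit): in the emitted loop over the new loop variable i, the block assigned to pipeline stage s roughly executes original iteration i - s, so with two stages the stage-1 consumer at i = 5 works on what the stage-0 producer generated at i = 4. normalized_access_index records this skewed position, which is what later feeds the producer_head/consumer_head bookkeeping when wait counts are computed.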
@@ -655,8 +701,8 @@ class PipelineRewriter : public StmtExprMutator {
  if (!is_unit_loop) {
    Map<String, ObjectRef> preserved_annotations;
    for (const auto &kv : pipeline_loop_->annotations) {
      const String &key = kv.first;
      if (kv.first != tir::attr::software_pipeline_stage &&
          kv.first != tir::attr::software_pipeline_order &&
          kv.first != tir::attr::software_pipeline_async_stages) {
...
...
@@ -664,16 +710,17 @@ class PipelineRewriter : public StmtExprMutator {
      }
    }
    new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                   unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                   std::move(new_loop), NullOpt, preserved_annotations);
  }
  // Update producer heads in the global async states.
  for (const auto &[stage_id, state] : async_states_local) {
    async_states[stage_id].producer_head += extent;
  }
  return BlockRealize({}, Bool(true),
                      MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
}

arith::Analyzer analyzer_;
...
...
@@ -690,22 +737,23 @@ class PipelineRewriter : public StmtExprMutator {
/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the
 * destination to the source.
 */
void BuildDependencyGraph(
    const Array<Block> &blocks,
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        *dep_src2dst,
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;
  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
        for (const Block &writer : it->second) {
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
...
...
@@ -715,83 +763,89 @@ void BuildDependencyGraph(
        }
      }
    }
    for (const BufferRegion &write : block->writes) {
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
public:
  static Stmt Inject(const PrimFunc &func) {
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies.
   *    Specifically, for dependency (e.g. read-after-write) from statement A
   *    to statement B, it requires:
   *      case 1: stage(A) < stage(B)
   *      case 2: stage(A) == stage(B) and order(A) < order(B)
   */
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
      used_orders.insert(order);
    }

    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
        if (src_info.stage == dst_info.stage) {
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
        }
      }
    }
  }
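To make the dependency rule enforced above concrete, here is a small standalone sketch; StageOrder and DependencyIsLegal are hypothetical names introduced only for illustration and are not part of this commit or of the TVM/TileLang API.

// Editor's illustrative sketch of the legality rule checked by
// ValidatePipelineBody; all names here are hypothetical.
#include <cassert>

struct StageOrder {
  int stage; // pipeline stage from software_pipeline_stage
  int order; // position from software_pipeline_order
};

// A dependency src -> dst is legal if the consumer runs in a later stage,
// or in the same stage but at a later position in the pipeline order.
bool DependencyIsLegal(const StageOrder &src, const StageOrder &dst) {
  if (src.stage < dst.stage)
    return true;
  return src.stage == dst.stage && src.order < dst.order;
}

int main() {
  assert(DependencyIsLegal({0, 0}, {1, 1}));  // producer one stage ahead
  assert(DependencyIsLegal({0, 0}, {0, 2}));  // same stage, later order
  assert(!DependencyIsLegal({1, 0}, {0, 1})); // consumer in an earlier stage
  return 0;
}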
  Stmt VisitStmt_(const ForNode *op) final {
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
    Stmt pipeline_body{nullptr};
    Array<Buffer> pipeline_allocs;
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
...
...
@@ -801,31 +855,32 @@ class PipelineInjector : private StmtExprMutator {
      pipeline_body = for_node->body;
    }
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq)
        << "ValueError: The body of the software pipeline "
           "should be SeqStmt, got "
        << pipeline_body->GetTypeKey();
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
    PipelineInfo pipeline_info;
    Array<Block> original_order; // pipeline body blocks in the original order
    auto f_add_child = [&](const Stmt &child) {
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
...
...
@@ -834,21 +889,26 @@ class PipelineInjector : private StmtExprMutator {
      }
    }

    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";

    std::unordered_set<int> pipeline_async_stages;
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
...
...
@@ -856,43 +916,44 @@ class PipelineInjector : private StmtExprMutator {
    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
      pipeline_info.emplace(original_order[i], stage_order);
    }
    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                                     GetRef<For>(op), pipeline_info)
                        .BuildPipeline();

    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  bool HasPipelineAnnotation(const ForNode *op) const {
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
...
...
@@ -901,10 +962,12 @@ class PipelineInjector : private StmtExprMutator {
      return true;
    }
    if (has_stage) {
      LOG(FATAL) << "ValueError: Order of the software pipeline is not defined.";
    }
    if (has_order) {
      LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined.";
    }
    return false;
  }
...
...
@@ -914,13 +977,13 @@ class PipelineInjector : private StmtExprMutator {
};
/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto *fptr = f.CopyOnWrite();
    fptr->body = PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
...
...
@@ -931,5 +994,5 @@ tir::transform::Pass InjectSoftwarePipeline() {
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
    .set_body_typed(InjectSoftwarePipeline);

} // namespace tl
} // namespace tvm
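For context on what this pass consumes, a minimal sketch follows of a serial loop carrying the three annotations the injector reads. The stage/order/async values assume a three-statement loop body and are purely illustrative; the helper name is hypothetical and not part of this commit.

// Editor's illustrative sketch (not part of this commit): constructing a loop
// that InjectSoftwarePipeline would recognize. Assumed: three child
// statements with stages {0, 0, 1}, orders {0, 2, 1}, and stage 0 async.
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

For MakeAnnotatedPipelineLoop(Var i, PrimExpr extent, Stmt body) {
  Map<String, ObjectRef> annotations;
  annotations.Set(attr::software_pipeline_stage, Array<Integer>{0, 0, 1});
  annotations.Set(attr::software_pipeline_order, Array<Integer>{0, 2, 1});
  annotations.Set(attr::software_pipeline_async_stages, Array<Integer>{0});
  return For(i, 0, extent, ForKind::kSerial, body, NullOpt, annotations);
}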
src/transform/layout_inference.cc
View file @
549416f7
...
...
@@ -30,11 +30,11 @@
#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"

namespace tvm {
namespace tl {
...
...
@@ -49,7 +49,7 @@ struct LayoutInferenceResult {
};

class BufferUseDefCollector : public StmtExprVisitor {
public:
  BufferUseDefCollector() = default;

  LayoutInferenceResult Run() {
...
...
@@ -59,22 +59,27 @@ class BufferUseDefCollector : public StmtExprVisitor {
    // maintain a bfs queue and infer common layout
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++)
      q.push(i);

    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto updates = next->InferLayout(
          LayoutInferArgs{target_,
                          static_cast<size_t>(
                              *as_const_int(iter_var->dom->extent)),
                          layout_map},
          level);
      for (const auto &[buffer, layout] : updates) {
        if (layout_map.count(buffer)) {
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer;
        } else {
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;
          for (int idx : use_list_[buffer]) {
            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
...
...
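The queue-driven inference above is a standard worklist fixed-point loop: run a step, and whenever it produces new information, re-enqueue every other step that depends on the affected buffer. A minimal, generic sketch of the same pattern (hypothetical names, independent of TileLang) might look like this:

// Editor's illustrative worklist sketch; RunToFixedPoint and its parameters
// are hypothetical and only mirror the shape of the inference loop above.
#include <queue>
#include <vector>

void RunToFixedPoint(int num_steps,
                     const std::vector<std::vector<int>> &users_of_step,
                     bool (*run_step)(int)) {
  std::queue<int> q;
  std::vector<bool> in_queue(num_steps, true);
  for (int i = 0; i < num_steps; i++)
    q.push(i);
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    in_queue[cur] = false;
    if (!run_step(cur))
      continue; // nothing new inferred by this step
    for (int user : users_of_step[cur]) {
      if (!in_queue[user] && user != cur) {
        in_queue[user] = true;
        q.push(user);
      }
    }
  }
}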
@@ -108,16 +113,17 @@ class BufferUseDefCollector : public StmtExprVisitor {
}
    // Check that all fragments have been inferred
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0)
        LOG_ERROR << "The layout for fragment " << buffer
                  << " can not be inferred correctly.";
    }

    // Collect the layout for for nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    for (auto &base_infer : infer_list_) {
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The Layout for Parallel for can not be inferred correctly : \n"
            << for_infer->GetRoot();
...
...
@@ -130,25 +136,27 @@ class BufferUseDefCollector : public StmtExprVisitor {
    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    StmtExprVisitor::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;
    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
...
...
@@ -158,7 +166,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
      }
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
...
...
@@ -167,7 +175,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
...
...
@@ -175,10 +183,10 @@ class BufferUseDefCollector : public StmtExprVisitor {
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
...
...
@@ -188,13 +196,14 @@ class BufferUseDefCollector : public StmtExprVisitor {
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
...
...
@@ -203,7 +212,7 @@ class BufferUseDefCollector : public StmtExprVisitor {
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
...
...
@@ -216,7 +225,8 @@ class BufferUseDefCollector : public StmtExprVisitor {
  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  IterVar thread_var_;
  std::vector<IterVar> thread_var_vec_;
  Target target_;
...
...
@@ -224,10 +234,10 @@ class BufferUseDefCollector : public StmtExprVisitor {
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector;
    collector.Collect(f);
...
...
@@ -237,11 +247,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result){};

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
    for (auto buffer : block->alloc_buffers) {
...
...
@@ -255,11 +266,12 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto loop_layout = result_.for_map[GetRef<For>(op)];
      for_node =
          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      for_node = VectorizeLoop(for_node);
      if (result_.predicate_map.count(GetRef<For>(op))) {
        return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
...
...
@@ -270,7 +282,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
...
...
@@ -281,7 +293,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_;
};
...
...
@@ -297,5 +309,5 @@ tvm::transform::Pass LayoutInference() {
TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);

} // namespace tl
} // namespace tvm
src/transform/legalize_safe_memory_access.cc
View file @
549416f7
...
...
@@ -30,8 +30,8 @@
#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
...
...
@@ -43,11 +43,11 @@ using arith::IRMutatorWithAnalyzer;
// Helper class to find leaf For nodes in a given IR
class LeafForFinder : public StmtVisitor {
public:
  std::vector<For> leaf_for_nodes;

private:
  void VisitStmt_(const ForNode *op) final {
    has_child_for_ = false;
    bool parent_has_child_for = parent_has_child_for_;
    parent_has_child_for_ = false;
...
...
@@ -62,7 +62,7 @@ class LeafForFinder : public StmtVisitor {
    parent_has_child_for_ = true;
  }

private:
  bool has_child_for_ = false;
  bool parent_has_child_for_ = false;
};
...
...
@@ -75,11 +75,11 @@ class LeafForFinder : public StmtVisitor {
// If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {
  arith::Analyzer *analyzer;

  explicit GlobalMemChecker(arith::Analyzer *analyzer) : analyzer(analyzer) {}

  void VisitExpr_(const BufferLoadNode *op) final {
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
...
...
@@ -87,7 +87,7 @@ struct GlobalMemChecker : public StmtExprVisitor {
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
...
...
@@ -96,21 +96,24 @@ struct GlobalMemChecker : public StmtExprVisitor {
  }

  // Helper function to determine if a buffer is global
  bool IsGlobalBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes. In typical TVM IR, global buffers have scope
    // "global". Here we assume a helper function GetPtrStorageScope is
    // available. If not, you might need to parse buffer->data->name_hint or
    // associated attributes.
    String scope = buffer.scope();
    return scope == "global";
  }

  // Check each index against the buffer shape dimensions
  void CheckBufferIndices(const Buffer &buffer, const Array<PrimExpr> &indices,
                          bool is_load) {
    // Ensure indices count matches buffer dimension
    if (indices.size() != buffer->shape.size()) {
      LOG(WARNING) << "Buffer access dimension mismatch: indices size ("
                   << indices.size() << ") vs. shape size ("
                   << buffer->shape.size() << ")";
      return;
    }
...
...
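A minimal sketch of the per-dimension bound check that CheckBufferIndices performs, using only the TVM arithmetic analyzer; the helper name CollectGuards is hypothetical and the real pass also handles the dimension-mismatch and lower-bound cases differently.

// Editor's illustrative sketch (hypothetical helper, not part of this commit):
// any index whose upper bound cannot be proven to stay within the buffer
// shape contributes a guard condition of the form (index < extent).
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>

using namespace tvm;

Array<PrimExpr> CollectGuards(arith::Analyzer *analyzer,
                              const Array<PrimExpr> &indices,
                              const Array<PrimExpr> &shape) {
  Array<PrimExpr> guards;
  for (size_t i = 0; i < indices.size(); ++i) {
    PrimExpr in_bound = indices[i] < shape[i];
    if (!analyzer->CanProve(in_bound)) {
      guards.push_back(in_bound); // the access must be predicated on this
    }
  }
  return guards;
}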
@@ -130,18 +133,19 @@ struct GlobalMemChecker : public StmtExprVisitor {
  Array<PrimExpr> GetConditions() { return _conditions; }

private:
  Array<PrimExpr> _conditions;
};

class SafeMemorysRewriter : public StmtExprMutator {
  arith::Analyzer *analyzer_;

public:
  explicit SafeMemorysRewriter(arith::Analyzer *analyzer)
      : analyzer_(analyzer) {}

private:
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    // Check if the buffer is in global scope
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    GlobalMemChecker checker(analyzer_);
...
...
@@ -173,12 +177,13 @@ class SafeMemorysRewriter : public StmtExprMutator {
  // Handle Call Nodes
  // For example
  // T.call_extern("handle", "atomicAddx2", T.address_of(C),
  //               T.address_of(C_shared))
  Stmt VisitStmt_(const EvaluateNode *op) final {
    auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
    auto call = Downcast<Call>(evaluate->value);
    if (call.defined() && call->op == builtin::call_extern()) {
      GlobalMemChecker checker(analyzer_);
      checker(call);
      Array<PrimExpr> conditions = checker.GetConditions();
...
...
@@ -197,13 +202,12 @@ class SafeMemorysRewriter : public StmtExprMutator {
    return evaluate;
  }

  bool isSharedBuffer(const Buffer &buffer) {
    String scope = buffer.scope();
    return scope == "shared" || scope == "shared.dyn";
  }

  bool IsGlobalBuffer(const Buffer &buffer) {
    String scope = buffer.scope();
    return scope == "global";
  }
...
...
@@ -211,32 +215,34 @@ class SafeMemorysRewriter : public StmtExprMutator {
// Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
public:
  // Static method to substitute and transform the given PrimFunc
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    // Create an instance of the legalizer with the analyzer
    SafeMemoryLegalizer substituter(&analyzer);
    // Get a mutable copy of the function node
    PrimFuncNode *fptr = f.CopyOnWrite();
    // Apply the legalizer to the function body
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  // Constructor initializing the base class with the analyzer
  SafeMemoryLegalizer(arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer) {}

  // Override the VisitStmt_ method to handle ForNode (loop statements)
  Stmt VisitStmt_(const ForNode *op) final {
    // Visit and potentially modify the loop node
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto has_inner_loop = HasInnerLoop(for_node->body);
    if (!has_inner_loop) {
      SafeMemorysRewriter rewriter(analyzer_);
      for_node.CopyOnWrite()->body = rewriter(for_node->body);
      // // Detect Buffer Load Node in the loop body, collect the indices and
      // // buffer size
      // // Run the checker on the loop body
      // GlobalMemChecker checker(analyzer_);
...
...
@@ -257,7 +263,7 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  static bool HasInnerLoop(const Stmt &stmt) {
    LeafForFinder finder;
    finder(stmt);
    return finder.leaf_for_nodes.size() > 0;
...
...
@@ -279,5 +285,5 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess")
    .set_body_typed(LegalizeSafeMemoryAccess);

} // namespace tl
} // namespace tvm