gaoqiong / composable_kernel_ROCM / Commits

Commit 1e339898, authored Oct 18, 2024 by aska-0096
temp save
parent 11444e4c

Showing 9 changed files with 1251 additions and 269 deletions (+1251 -269)
example/01_gemm/CMakeLists.txt                                                      +1    -0
example/01_gemm/gemm_xdl_fp16_v3.cpp                                                +3    -3
include/ck/tensor/static_tensor.hpp                                                 +26   -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp       +138  -22
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp   +6    -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp         +41   -37
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp              +168  -205
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp         +827  -0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                +41   -0
example/01_gemm/CMakeLists.txt

@@ -25,6 +25,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
target_compile_options(example_gemm_xdl_fp16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
example/01_gemm/gemm_xdl_fp16_v3.cpp

@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmV2Instance =

@@ -29,13 +29,13 @@ using DeviceGemmV2Instance =
        PassThrough, PassThrough, PassThrough, GemmDefault,
        256,
        224, 256, 64,
        8, 2,
        64, 8, 8,
        16, 16,
        7, 8,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
        S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0,
        1, 8, 8, 0,
        1, 2, S<1, 32, 1, 8>, 8,
        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// clang-format on
include/ck/tensor/static_tensor.hpp

@@ -218,6 +218,32 @@ struct StaticTensorTupleOfVectorBuffer
        }
    }

    template <typename X,
              typename Idx,
              typename enable_if<has_same_scalar_type<S, X>::value &&
                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr void SetAsType_Print(Idx, X x)
    {
        constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

        constexpr index_t offset = coord.GetOffset();

        if(get_thread_local_1d_id() == 0)
        {
            printf("Tid: %d, Index: (%d, %d, %d, %d), Offset: %d\n",
                   get_thread_local_1d_id(),
                   Idx{}.At(Number<0>{}).value,
                   Idx{}.At(Number<1>{}).value,
                   Idx{}.At(Number<2>{}).value,
                   Idx{}.At(Number<3>{}).value,
                   offset);
        }

        constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            data_.template SetAsType<X>(Number<offset>{}, x);
        }
    }

    // Get read access to V. No is_valid check
    // Idx is for S, not V. Idx should be aligned with V
    template <typename Idx>
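A note on the change above: SetAsType_Print is the regular SetAsType plus a printf that only fires from thread 0. Below is a minimal, self-contained sketch of that same "print only from one thread" debugging pattern, assuming a HIP toolchain; the kernel name, buffer, and sizes here are illustrative only and are not part of the commit.

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void debug_offset_kernel(const int* offsets)
{
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;

    // Guard the printf so only one thread reports, mirroring the
    // get_thread_local_1d_id() == 0 check inside SetAsType_Print.
    if(threadIdx.x == 0)
    {
        printf("Tid: %d, Offset: %d\n", tid, offsets[tid]);
    }
}

int main()
{
    int host_offsets[64];
    for(int i = 0; i < 64; ++i) { host_offsets[i] = i * 8; }

    int* dev_offsets = nullptr;
    hipMalloc(&dev_offsets, sizeof(host_offsets));
    hipMemcpy(dev_offsets, host_offsets, sizeof(host_offsets), hipMemcpyHostToDevice);

    debug_offset_kernel<<<1, 64>>>(dev_offsets);
    hipDeviceSynchronize();

    hipFree(dev_offsets);
    return 0;
}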
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp

@@ -30,6 +30,8 @@ template <index_t BlockSize,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool TransposeA = false,
          bool TransposeB = false,
          bool TransposeC = false>
struct BlockwiseGemmXdlops_pipeline_base
{

@@ -152,6 +154,38 @@ struct BlockwiseGemmXdlops_pipeline_base
        return make_tuple(c_thread_m, c_thread_n);
    }

    // Contiguous output tile
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto CalculateCThreadOriginDataIndexContiguous(
        Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(waveId_n, blk_idx[I1], n0))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto CalculateCThreadOriginDataIndex8D(
        Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)

@@ -212,6 +246,21 @@ struct BlockwiseGemmXdlops_pipeline_base
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    // Contiguous output tile
    __host__ __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, M1, N, Number<NRepeat>{}, M2));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

@@ -253,6 +302,23 @@ struct BlockwiseGemmXdlops_pipeline_base
        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
    {
        constexpr auto c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
            c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =

@@ -327,28 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_base
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
    template <bool Transpose>
    struct AThreadCopySelector;

    template <>
    struct AThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
                                                      ComputeDataType,
                                                      decltype(a_block_desc_m0_m1_m2_k),
                                                      decltype(a_thread_desc_),
                                                      Sequence<1, 1, 1, KPack>,
                                                      Sequence<MRepeat, 1, 1, KPack>,
                                                      Sequence<0, 1, 2, 3>,
                                                      Sequence<0, 1, 2, 3>,
                                                      3,
                                                      3,
                                                      A_K1,
                                                      A_K1>;
    };

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
    template <>
    struct AThreadCopySelector<true>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
                                                      ComputeDataType,
                                                      decltype(a_block_desc_m0_m1_m2_k),
                                                      decltype(a_thread_desc_),
                                                      Sequence<MRepeat, 1, 1, KPack>,
                                                      Sequence<3, 1, 2, 0>,
                                                      Sequence<0, 1, 2, 3>,
                                                      0,
                                                      3,
                                                      MRepeat,
                                                      A_K1>;
    };

    template <bool Transpose>
    struct BThreadCopySelector;

    template <>
    struct BThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
                                                      ComputeDataType,
                                                      decltype(b_block_desc_n0_n1_n2_k),
                                                      decltype(b_thread_desc_),
                                                      Sequence<1, 1, 1, KPack>,
                                                      Sequence<NRepeat, 1, 1, KPack>,
                                                      Sequence<0, 1, 2, 3>,
                                                      Sequence<0, 1, 2, 3>,
                                                      3,
                                                      3,
                                                      B_K1,
                                                      B_K1>;
    };

    template <>
    struct BThreadCopySelector<true>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
                                                      ComputeDataType,
                                                      decltype(b_block_desc_n0_n1_n2_k),
                                                      decltype(b_thread_desc_),
                                                      Sequence<NRepeat, 1, 1, KPack>,
                                                      Sequence<3, 1, 2, 0>,
                                                      Sequence<0, 1, 2, 3>,
                                                      0,
                                                      3,
                                                      NRepeat,
                                                      B_K1>;
    };

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
    typename AThreadCopySelector<TransposeA>::type a_thread_copy_;
    typename BThreadCopySelector<TransposeB>::type b_thread_copy_;
};
} // namespace ck
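A stripped-down illustration of the selector pattern this hunk introduces (AThreadCopySelector / BThreadCopySelector): a bool-specialized class template whose member typedef picks the concrete copy type at compile time. CopyRowMajor, CopyTransposed, and PipelineSketch below are placeholders invented for the sketch, not CK types.

#include <iostream>

struct CopyRowMajor   { static const char* name() { return "row-major copy"; } };
struct CopyTransposed { static const char* name() { return "transposed copy"; } };

// Primary template is only declared; each bool value gets its own full
// specialization, so the chosen member type is resolved entirely at compile time.
template <bool Transpose>
struct ThreadCopySelector;

template <>
struct ThreadCopySelector<false> { using type = CopyRowMajor; };

template <>
struct ThreadCopySelector<true> { using type = CopyTransposed; };

template <bool TransposeA>
struct PipelineSketch
{
    // Mirrors "typename AThreadCopySelector<TransposeA>::type a_thread_copy_;"
    typename ThreadCopySelector<TransposeA>::type a_thread_copy_;
};

int main()
{
    PipelineSketch<false> p0;
    PipelineSketch<true>  p1;
    std::cout << decltype(p0.a_thread_copy_)::name() << "\n"
              << decltype(p1.a_thread_copy_)::name() << "\n";
    return 0;
}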
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp

@@ -40,7 +40,9 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
          index_t KPack,
          bool TransposeA,
          bool TransposeB>
constexpr auto BlockGemmPipeline_Selector()
{
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)

@@ -110,7 +112,9 @@ constexpr auto BlockGemmPipeline_Selector()
                                           NPerXDL,
                                           MRepeat,
                                           NRepeat,
                                           KPack>{};
                                           KPack,
                                           TransposeA,
                                           TransposeB>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
    {
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp

@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
          index_t KPack,
          bool TransposeA,
          bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v3
{
};

@@ -55,7 +57,9 @@ template <index_t BlockSize,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          index_t KPack,
          bool TransposeA,
          bool TransposeB
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,

@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
          NPerXDL,
          MRepeat,
          NRepeat,
          KPack>
          KPack,
          TransposeA,
          TransposeB>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,

@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
                                        KPack,
                                        TransposeA,
                                        TransposeB>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,

@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
                                                   KPack,
                                                   TransposeA,
                                                   TransposeB>;

    using Base::I0;
    using Base::I1;
    using Base::KRepeat;

@@ -322,23 +332,20 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
            // Local prefetch 1
            block_sync_lds();
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k0, I0),
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                    //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
                    //                    a_block_buf,
                    //                    a_thread_desc_,
                    //                    make_tuple(I0, I0, k0, I0),
                    //                    a_thread_buf);
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                       make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k0, I0),
                                       make_tuple(I0, I0, k0, I0),
                                       b_thread_buf);
                });
            });

            __builtin_amdgcn_sched_barrier(0);

@@ -392,23 +399,20 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                block_sync_lds();
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k0, I0),
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                        //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
                        //                    a_block_buf,
                        //                    a_thread_desc_,
                        //                    make_tuple(I0, I0, k0, I0),
                        //                    a_thread_buf);
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                           make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k0, I0),
                                           make_tuple(I0, I0, k0, I0),
                                           b_thread_buf);
                    });
                });
                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -221,6 +221,41 @@ struct GridwiseGemm_xdl_cshuffle_v3
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptorCongruous(const TileDesc_K0_MN_K1&)
    {
        constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
        constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

        return transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                       make_unmerge_transform(
                           make_tuple(Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<1, 2, 0>{}));
#if 0
        constexpr auto mma_transformed =
        transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                       make_unmerge_transform(make_tuple(
                           Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));

        return transform_tensor_descriptor(
            mma_transformed,
            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
                       make_pass_through_transform(Number<MNWaves>{}),
                       make_pass_through_transform(Number<MNPerXdl>{}),
                       make_pass_through_transform(Number<MNXdlPerWave>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<0>{}));
#endif
    }

    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {

@@ -391,7 +426,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
        constexpr auto a_mma_desc = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return MakeGemmMmaTileDescriptorCongruous<MXdlPerWave, MWaves, MPerXdl>(
                    ABlockDesc_AK0_M_AK1{});
            }
        }();
        return a_mma_desc;
    }

    template <typename BBlockDesc_BK0_N_BK1>

@@ -400,7 +448,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
        constexpr auto b_mma_desc = [&]() {
            if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
            }
            else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>(
                    BBlockDesc_BK0_N_BK1{});
            }
        }();
        return b_mma_desc;
    }

    __host__ __device__ static auto

@@ -662,6 +723,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }
        else // ColumnMajor A
        {
#if 0
            // kfold and mpair dimension is not always required.
            // more dimension in merge_transform increase the difficulty of generating immarg offset
            // for compiler.

@@ -746,6 +808,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
#endif
            static_assert(ABlockTransferSrcScalarPerVector % MXdlPerWave == 0);
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
                make_tuple(Number<AK1Number * MPerBlock>{}, I1, Number<MPerBlock>{}));
        }
    }

@@ -799,6 +867,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }
        else // RowMajor B
        {
#if 0
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

@@ -880,6 +949,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
#endif
            static_assert(BBlockTransferSrcScalarPerVector % NXdlPerWave == 0);
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
                make_tuple(Number<BK1Number * NPerBlock>{}, I1, Number<NPerBlock>{}));
        }
    }

@@ -922,7 +997,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                        NPerXdl,
                                        MXdlPerWave,
                                        NXdlPerWave,
                                        KPack>())>;
                                        KPack,
                                        is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value,
                                        is_same<tensor_layout::gemm::RowMajor, BLayout>::value>())>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {

@@ -1255,11 +1332,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            ABlockTransferSrcScalarPerVector,
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,

@@ -1286,11 +1363,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            BBlockTransferSrcScalarPerVector,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,

@@ -1343,63 +1420,32 @@ struct GridwiseGemm_xdl_cshuffle_v3
            c_thread_buf,
            num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            // Epilogue
            constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
            constexpr auto c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
                blockwise_gemm_pipeline.GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();

            constexpr auto M0 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
            constexpr auto M1 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
            constexpr auto N0 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
            constexpr auto M2 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
            constexpr auto M3 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
            constexpr auto N1 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
            constexpr auto N2 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<8>{});
            constexpr auto M4 = c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<9>{});

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

@@ -1414,8 +1460,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

@@ -1424,121 +1469,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
            // Typecast -> Permute -> Coalesced vector store
            auto c_thread_copy_vgpr_to_global =
                ThreadwiseTensorSliceTransfer_v1r4<AccDataType,
                                                   CDataType,
                                                   decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
                                                   decltype(c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
                                                   CElementwiseOperation,
                                                   Sequence<I1,
                                                            I1,
                                                            M0,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            I1,
                                                            N2,
                                                            M4>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
                                                   9,
                                                   8,
                                                   M4,
                                                   I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   N2,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                                   false>{
                    c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                    make_multi_index(block_m_id,
                                     block_n_id,
                                     m_thread_data_on_block_idx[I0],
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I0],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                CShuffleDataType,     // typename SrcData,
                CDataType,            // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 n_thread_data_on_block_idx[I1],
                 n_thread_data_on_block_idx[I2],
                 m_thread_data_on_block_idx[I4]),
                 c_element_op};

            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1, 1, M2, 1, M4, 1>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copy its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }

    template <bool HasMainKBlockLoop,
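The a_mma_desc / b_mma_desc hunks above use an immediately invoked lambda whose "if constexpr" branches return descriptors of different types, so the tile-descriptor type itself depends on the layout tag. Below is a simplified, self-contained illustration of that dispatch trick; RowMajor, ColumnMajor, and the two descriptor structs are placeholders invented for the sketch, not the CK types.

#include <iostream>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

struct KMajorDescriptor    { static constexpr const char* kind = "k-major (direct) tile"; };
struct CongruousDescriptor { static constexpr const char* kind = "congruous (MN-major) tile"; };

template <typename Layout>
constexpr auto make_mma_tile_descriptor()
{
    // The lambda is evaluated immediately; only one branch is instantiated,
    // so the deduced return type differs per layout.
    constexpr auto desc = []() {
        if constexpr(std::is_same<Layout, RowMajor>::value)
        {
            return KMajorDescriptor{};
        }
        else
        {
            return CongruousDescriptor{};
        }
    }();
    return desc;
}

int main()
{
    auto a = make_mma_tile_descriptor<RowMajor>();
    auto b = make_mma_tile_descriptor<ColumnMajor>();
    std::cout << decltype(a)::kind << "\n" << decltype(b)::kind << "\n";
    return 0;
}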
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
#include "ck/tensor/static_tensor.hpp"

namespace ck {

// Assume:

@@ -189,6 +190,396 @@ struct ThreadwiseTensorSliceTransfer_v1r3
    const ElementwiseOperation element_op_;
}; // namespace ThreadwiseTensorSliceTransfer_v1r3

// Assume:
//   1. src:
//     1. SrcDesc is known at compile-time
//     2. SrcBuffer is StaticBuffer
//     3. SrcSliceOrginIdx is known at compile-time
//   2. dst:
//     1. DstDesc is not known at compile-time
//     2. DstBuffer is DynamicBuffer
//     3. DstSliceOrginIdx is not known at compile time
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          InMemoryDataOperationEnum DstInMemOp,
          bool DstResetCoordinateAfterRun,
          typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r4
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index = MultiIndex<nDim>;

    using DstCoord     = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r4(const DstDesc& dst_desc,
                                                            const Index& dst_slice_origin_idx,
                                                            const ElementwiseOperation& element_op)
        : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
          element_op_{element_op}
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
                      "wrong! SrcSliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");

        constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
        constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

        using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                 DstData,
                                                                 SrcScalarPerVector,
                                                                 decltype(src_thread_scratch_desc_),
                                                                 true>;
        using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                 DstData,
                                                                 DstScalarPerVector,
                                                                 decltype(dst_thread_scratch_desc_),
                                                                 true>;

        SrcThreadScratch src_thread_scratch_;
        DstThreadScratch dst_thread_scratch_;

        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

        // scalar per access on each dim
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, src_dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            vector_type_maker_t<DstData, SrcScalarPerVector> src_vector;

            using src_vector_t = typename decltype(src_vector)::type;

            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(src_dim_access_order) * src_scalar_per_access;

            // copy data from src_buf into dst_vector
            // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                DstData v;

                // apply element-wise operation
                element_op_(v, src_buf[Number<src_offset>{}]);

                src_vector.template AsType<DstData>()(i) = v;
            });

            // Set data to scratch
            src_thread_scratch_.template SetAsType<src_vector_t>(
                data_to_origin_disp_idx, src_vector.template AsType<src_vector_t>()[Number<0>{}]);
        });

        if constexpr(SrcVectorDim != DstVectorDim &&
                     ((is_same<half_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<bhalf_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<int8_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
                      (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
        {
            // each transpose does
            // DstScalarPerVector # of src vectors in src_thread_scratch_
            // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
            constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
            constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};

            // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
            // TODO: make this logic generic for all scenario
            static_assert(SrcVectorDim != DstVectorDim, "wrong");

            constexpr auto src_scalar_step_in_vector_ = generate_sequence(
                detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
            constexpr auto dst_scalar_step_in_vector_ = generate_sequence(
                detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

            constexpr auto scalar_per_access_ = generate_sequence(
                detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
                                                                 SrcScalarPerVector,
                                                                 DstVectorDim,
                                                                 DstScalarPerVector>{},
                Number<nDim>{});

            constexpr auto access_lengths_ = SliceLengths{} / scalar_per_access_;

            static_ford<decltype(access_lengths_)>{}([&](auto access_idx_) {
                constexpr auto data_idx_ = access_idx_ * scalar_per_access_;

                constexpr auto data_idx_seq_ = generate_sequence_v2(
                    [&](auto i) { return Number<data_idx_[i]>{}; }, Number<nDim>{});

                using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
                using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;

                // get DstScalarPerVector # of read-only references to src vectors from
                // src_thread_scratch_
                const auto src_vector_refs = generate_tie(
                    [&](auto i) -> const src_vector_t& {
                        // i increment corresponds to movement in DstVectorDim
                        return src_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq_ + i * dst_scalar_step_in_vector_);
                    },
                    Number<num_src_vector>{});

                // get SrcScalarPerVector # of references to dst vectors from
                // dst_thread_scratch_
                auto dst_vector_refs = generate_tie(
                    [&](auto i) -> dst_vector_t& {
                        // i increment corresponds to movement in SrcVectorDim
                        return dst_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq_ + i * src_scalar_step_in_vector_);
                    },
                    Number<num_dst_vector>{});

                // do data transpose
                transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
            });
        }
        else
        {
            static_ford<SliceLengths>{}(
                [&](auto idx) { dst_thread_scratch_(idx) = src_thread_scratch_[idx]; });
        }

        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(access_lengths, dst_dim_access_order);

        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;

            constexpr auto data_to_origin_disp_idx =
                ordered_dst_access_idx.ReorderGivenOld2New(dst_dim_access_order) *
                dst_scalar_per_access;

            // copy data from dst_thread_scratch_ into dst_vector_container
            auto dst_vector = dst_vector_type{
                dst_thread_scratch_.template GetAsType<dst_vector_t>(data_to_origin_disp_idx)};

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);

            move_tensor_coordinate(
                dst_desc,
                dst_coord_,
                make_tensor_coordinate_step(dst_desc, to_multi_index(data_to_origin_disp_idx)));
        });

        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});

        // 1st stage of transforms
        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(src_access_lengths_and_vector_length[i],
                                   src_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return Sequence<i.value, nDim>{};
                }
                else
                {
                    return Sequence<i.value>{};
                }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    __device__ static constexpr auto GetDstThreadScratchDescriptor()
    {
        // 1st stage of transforms
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});

        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(dst_access_lengths_and_vector_length[i],
                                   dst_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return Sequence<i.value, nDim>{};
                }
                else
                {
                    return Sequence<i.value>{};
                }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DstDimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx = DstResetCoordinateAfterRun
                                           ? dst_slice_origin_step_idx
                                           : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    DstCoord dst_coord_;
    const ElementwiseOperation element_op_;
}; // namespace ThreadwiseTensorSliceTransfer_v1r4

// Assume:
//   1. src:
//     1. SrcDesc is not known at compile-time
@@ -1202,6 +1593,442 @@ struct ThreadwiseTensorSliceTransfer_v4
    SrcCoord src_ref_coord_;
};

// Assume:
//   1. src:
//     1. SrcDesc is known at compile-time
//     2. SrcBuffer is DynamicBuffer
//     3. src_ref_idx is known at run-time
//     4. SrcRefToOriginDisplacement is known at compile-time
//     5. use #-step
//   2. dst:
//     1. DstDesc is known at compile-time
//     2. DstBuffer is StaticBuffer
//     3. DstOriginIdx is known at compile-time
//     4. use direct address calculation
//   3. vector access on src
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v5
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord     = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v5(const Index& src_ref_idx)
        : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                      "wrong! Not divisible");
    }

    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
                      "at compile-time");

        constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
        constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

        using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                 SrcData,
                                                                 SrcScalarPerVector,
                                                                 decltype(src_thread_scratch_desc_),
                                                                 true>;
        using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                 SrcData,
                                                                 DstScalarPerVector,
                                                                 decltype(dst_thread_scratch_desc_),
                                                                 true>;

        SrcThreadScratch src_thread_scratch_;
        DstThreadScratch dst_thread_scratch_;
        // src_thread_scratch_desc_++;

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};

        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});

        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, src_dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(src_dim_access_order) * src_scalar_per_access;
#if 0
            if (get_thread_local_1d_id()==0){
                printf("%d, %d, %d, %d\n",
                data_to_origin_disp_idx.At(Number<0>{}).value,
                data_to_origin_disp_idx.At(Number<1>{}).value,
                data_to_origin_disp_idx.At(Number<2>{}).value,
                data_to_origin_disp_idx.At(Number<3>{}).value);
            }
#endif
            // src coordinate
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_step =
                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
#if 0
            if (get_thread_local_1d_id()<32){
                printf("Tid: %02d, Index(%d, %d, %d, %d), offset: %d\n", get_thread_local_1d_id(), src_data_coord.GetIndex().At(Number<0>{}),
                src_data_coord.GetIndex().At(Number<1>{}),
                src_data_coord.GetIndex().At(Number<2>{}),
                src_data_coord.GetIndex().At(Number<3>{}), src_data_coord.GetOffset());
            }
#endif
            // Set data to scratch
            src_thread_scratch_.template SetAsType_Print<src_vector_t>(
                data_to_origin_disp_idx, src_tmp_vector.template AsType<src_vector_t>()[I0]);
        });

        if constexpr(SrcVectorDim != DstVectorDim &&
                     ((is_same<half_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<bhalf_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<int8_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
                      (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
        {
            // each transpose does
            // DstScalarPerVector # of src vectors in src_thread_scratch_
            // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
            constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
            constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};

            // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
            // TODO: make this logic generic for all scenario
            static_assert(SrcVectorDim != DstVectorDim, "wrong");

            constexpr auto src_scalar_step_in_vector_ = generate_sequence(
                detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
            constexpr auto dst_scalar_step_in_vector_ = generate_sequence(
                detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

            constexpr auto scalar_per_access_ = generate_sequence(
                detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
                                                                 SrcScalarPerVector,
                                                                 DstVectorDim,
                                                                 DstScalarPerVector>{},
                Number<nDim>{});

            constexpr auto access_lengths_ = SliceLengths{} / scalar_per_access_;

            static_ford<decltype(access_lengths_)>{}([&](auto access_idx_) {
                constexpr auto data_idx_ = access_idx_ * scalar_per_access_;

                constexpr auto data_idx_seq_ = generate_sequence_v2(
                    [&](auto i) { return Number<data_idx_[i]>{}; }, Number<nDim>{});

                using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
                using dst_vector_t = vector_type_maker_t<SrcData, DstScalarPerVector>;

                // get DstScalarPerVector # of read-only references to src vectors from
                // src_thread_scratch_
                const auto src_vector_refs = generate_tie(
                    [&](auto i) -> const src_vector_t& {
                        // i increment corresponds to movement in DstVectorDim
                        return src_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq_ + i * dst_scalar_step_in_vector_);
                    },
                    Number<num_src_vector>{});

                // get SrcScalarPerVector # of references to dst vectors from
                // dst_thread_scratch_
                auto dst_vector_refs = generate_tie(
                    [&](auto i) -> dst_vector_t& {
                        // i increment corresponds to movement in SrcVectorDim
                        return dst_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq_ + i * src_scalar_step_in_vector_);
                    },
                    Number<num_dst_vector>{});

                // do data transpose
                transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
            });
        }
        else
        {
            static_ford<SliceLengths>{}(
                [&](auto idx) { dst_thread_scratch_(idx) = src_thread_scratch_[idx]; });
        }

        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        // scalar step (if steping on SrcVectorDim) of each dim
        constexpr auto dst_scalar_step_in_vector = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == DstVectorDim)
                {
                    return Number<1>{};
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<nDim>{});

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(access_lengths, dst_dim_access_order);

        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_dst_access_idx.ReorderGivenOld2New(dst_dim_access_order) *
                dst_scalar_per_access;

            using src_vector_type = vector_type_maker_t<SrcData, DstScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;

            // copy data from dst_thread_scratch_ into dst_vector_container
            auto src_tmp_vector = src_vector_type{
                dst_thread_scratch_.template GetAsType<src_vector_t>(data_to_origin_disp_idx)};

            if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                         is_same<remove_cvref_t<DstData>, half_t>::value &&
                         SrcScalarPerVector % 2 == 0)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                // DstData)
                vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;

                constexpr index_t pack_size = 2;

                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
                using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;

                static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::PassThroughPack2{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * dst_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else
            {
                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                // DstData)
                vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;

                // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                    dst_tmp_vector.template AsType<DstData>()(i) =
                        type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * dst_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
        });
    }

    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});

        // 1st stage of transforms
        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(src_access_lengths_and_vector_length[i],
                                   src_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return Sequence<i.value, nDim>{};
                }
                else
                {
                    return Sequence<i.value>{};
                }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    __device__ static constexpr auto GetDstThreadScratchDescriptor()
    {
        // 1st stage of transforms
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});

        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(dst_access_lengths_and_vector_length[i],
                                   dst_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return Sequence<i.value, nDim>{};
                }
                else
                {
                    return Sequence<i.value>{};
                }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
    {
        constexpr auto src_desc = SrcDesc{};

        const auto src_slice_move_step_iter =
            make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));

        move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
    }

    __device__ void SetSrcCoord(const Index& src_ref_idx)
    {
        src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
    }

    private:
    SrcCoord src_ref_coord_;
};

/**
 * @brief Threadwise data transfer
 *
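Both new transfer classes route through transpose_vectors when SrcVectorDim and DstVectorDim differ: DstScalarPerVector source vectors of SrcScalarPerVector elements are rewritten as SrcScalarPerVector destination vectors of DstScalarPerVector elements. The CPU-side toy below illustrates only that index relationship; the sizes (8 and 2) are made up for the example and nothing here uses the CK API.

#include <array>
#include <cstdio>

int main()
{
    constexpr int SrcScalarPerVector = 8;
    constexpr int DstScalarPerVector = 2;

    // DstScalarPerVector "source vectors", each SrcScalarPerVector wide.
    std::array<std::array<int, SrcScalarPerVector>, DstScalarPerVector> src{};
    for(int v = 0; v < DstScalarPerVector; ++v)
        for(int e = 0; e < SrcScalarPerVector; ++e)
            src[v][e] = v * 100 + e;

    // SrcScalarPerVector "destination vectors", each DstScalarPerVector wide.
    std::array<std::array<int, DstScalarPerVector>, SrcScalarPerVector> dst{};

    // The transpose: element e of source vector v becomes element v of
    // destination vector e.
    for(int v = 0; v < DstScalarPerVector; ++v)
        for(int e = 0; e < SrcScalarPerVector; ++e)
            dst[e][v] = src[v][e];

    for(int e = 0; e < SrcScalarPerVector; ++e)
        std::printf("dst[%d] = {%d, %d}\n", e, dst[e][0], dst[e][1]);

    return 0;
}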
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -950,6 +950,47 @@ struct XdlopsGemm
                       Sequence<7>{}));
    }

    template <typename CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
        const CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2& c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2)
    {
        const auto MBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto NBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto M0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto N0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I3);
        const auto M1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I4);
        const auto N1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I5);

        return transform_tensor_descriptor(
            c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(MBlock),
                       make_pass_through_transform(NBlock),
                       make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
                                                         Number<mfma_instr.num_input_blks>{},
                                                         Number<mfma_instr.group_size>{})),
                       make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                       Sequence<4>{}, Sequence<5>{}, Sequence<6>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<8>{},
                       Sequence<3>{}, Sequence<4>{}, Sequence<5, 6, 9>{}, Sequence<7>{}));
    }

    // transposed XDL output supporting C' = B' * A'
    // M2_N2 -> M2_N2_N3_N4
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
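For readers unfamiliar with unmerge transforms: splitting one dimension into (num_groups_per_blk, num_input_blks, group_size) is ordinary integer div/mod arithmetic on a flat index. The snippet below only illustrates that arithmetic with one possible ordering (fastest dimension last); the sizes 4, 2, 4 are invented for the example, and no claim is made that this matches the exact ordering or any real MFMA instruction's parameters.

#include <cstdio>

int main()
{
    constexpr int num_groups_per_blk = 4;
    constexpr int num_input_blks     = 2;
    constexpr int group_size         = 4;

    // Split a flat index into (group, blk, element-in-group), with the last
    // component varying fastest. This is one common convention; the actual
    // descriptor machinery may order things differently.
    for(int flat = 0; flat < num_groups_per_blk * num_input_blks * group_size; ++flat)
    {
        const int within = flat % (num_input_blks * group_size);
        const int group  = flat / (num_input_blks * group_size);
        const int blk    = within / group_size;
        const int elem   = within % group_size;
        std::printf("flat %2d -> (group %d, blk %d, elem %d)\n", flat, group, blk, elem);
    }
    return 0;
}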