Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ed068043
Commit
ed068043
authored
Nov 15, 2021
by
Jing Zhang
Browse files
merged develop
parents
41852668
e823d518
Changes
74
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1774 additions
and
375 deletions
+1774
-375
CMakeLists.txt
CMakeLists.txt
+4
-0
README.md
README.md
+0
-176
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
...ht_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
+147
-0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
...ht_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
+147
-0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
...rd_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
+132
-0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp
...rd_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp
+144
-0
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
...m_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
+1
-2
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+17
-17
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+179
-162
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+680
-0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+1
-1
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+7
-7
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+1
-1
composable_kernel/include/utility/type.hpp
composable_kernel/include/utility/type.hpp
+3
-0
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
...tion_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
+9
-9
device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp
...n/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp
+64
-0
device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp
...n/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp
+64
-0
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp
+58
-0
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp
+58
-0
device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp
+58
-0
No files found.
CMakeLists.txt
View file @
ed068043
...
...
@@ -40,6 +40,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
## half
#find_path(HALF_INCLUDE_DIR half.hpp)
set
(
HALF_INCLUDE_DIR
"
${
PROJECT_SOURCE_DIR
}
/external/half/include"
)
message
(
"HALF_INCLUDE_DIR:
${
HALF_INCLUDE_DIR
}
"
)
# CMAKE_CXX_FLAGS
...
...
@@ -185,6 +186,7 @@ enable_cppcheck(
composable_kernel/src/kernel_wrapper
INCLUDE
host/host_tensor/include
host/device/include
host/solver/include
host/driver_offline/include
composable_kernel/include/*
...
...
@@ -196,3 +198,5 @@ enable_cppcheck(
)
add_subdirectory
(
host
)
add_subdirectory
(
example
)
add_subdirectory
(
profiler
)
README.md
View file @
ed068043
# How to build and run
# Docker
```
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.2-tf2.4-dev \
/bin/bash
```
# Install Boost for online compilation
https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install
# Build
Add path of Boost
```
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
```
```
mkdir build && cd build
```
cmake cmd. Need to Specify target ID, example below is gfx908
```
cmake \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
..
```
Build drivers:
\
``conv_fwd_driver_offline``
is (offline compilation) driver for forward convolution,
\
``conv_bwd_driver_offline``
is (offline compilation) driver for backward-data convolution
\
``conv_fwd_driver_online``
is (online compilation) driver for forward convolution
```
make -j conv_fwd_driver_offline
make -j conv_bwd_driver_offline
make -j conv_fwd_driver_online
```
# Run
*
layout: 0 = NCHW; 1 = NHWC
*
algo: algorithm
*
verify: 0 = no verification; 1 = do verification
*
init: 0 ~ 5. initialization method
*
log: 0 = no log; 1 = do log
*
repeat: number of time kernel being launched
```
######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
```
# Result
Forward convoltuion, FP16, NCHW
```
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
layout: 0
in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1}
wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1}
out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{216, 256, 8}
b_k0_n_k1_grid_desc{216, 165888, 8}
c_m_n_grid_desc{ 256, 165888}
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 1.4155 ms, 103.686 TFlop/s
```
Forward convoltuion, FP16, NCHW
```
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
layout: 0
in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1}
wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1}
out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{288, 1024, 8}
b_k0_n_k1_grid_desc{288, 50176, 8}
c_m_n_grid_desc{ 1024, 50176}
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 2.21357 ms, 106.959 TFlop/s
```
Forward convolution, FP16, NHWC
```
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
layout: 1
in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1}
wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1}
out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{216, 165888, 8}
b_k0_n_k1_grid_desc{216, 256, 8}
c_m_n_grid_desc{ 165888, 256}
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 1.12014 ms, 131.025 TFlop/s
```
Forward convolution, FP16, NHWC
```
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
layout: 1
in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1}
out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 1.86877 ms, 126.693 TFlop/s
```
Backward data convolution, FP16, NHWC
```
./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
layout: 1
in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1}
out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 2.22461 ms, 106.428 TFlop/s
```
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
ed068043
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmK1Value
,
typename
GemmKBatchType
,
typename
GemmKPadType
>
__host__
__device__
constexpr
auto
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad
(
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
GemmK1Value
>
,
GemmKBatchType
GemmKBatch
,
GemmKPadType
GemmKPad
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
const
auto
N
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I1
);
const
auto
K
=
out_n_k_ho_wo_grid_desc
.
GetLength
(
I1
);
const
auto
Hi
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I2
);
const
auto
Wi
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I3
);
const
auto
Ho
=
out_n_k_ho_wo_grid_desc
.
GetLength
(
I2
);
const
auto
Wo
=
out_n_k_ho_wo_grid_desc
.
GetLength
(
I3
);
const
auto
Y
=
wei_k_c_y_x_grid_desc
.
GetLength
(
I2
);
const
auto
X
=
wei_k_c_y_x_grid_desc
.
GetLength
(
I3
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
K
;
const
auto
GemmN
=
C
*
Y
*
X
;
const
auto
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmK0
=
GemmKPad
/
(
GemmKBatch
*
GemmK1
);
// A: output tensor
const
auto
out_gemmktotal_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_n_c_hip_wip_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// namespace ck
#endif
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
ed068043
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template
<
typename
...
In
,
typename
...
Wei
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmK1Value
,
typename
GemmKBatchType
,
typename
GemmKPadType
>
__host__
__device__
constexpr
auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad
(
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
GemmK1Value
>
,
GemmKBatchType
GemmKBatch
,
GemmKPadType
GemmKPad
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
const
auto
N
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I3
);
const
auto
K
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I3
);
const
auto
Hi
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I1
);
const
auto
Wi
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I1
);
const
auto
Wo
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_y_x_c_grid_desc
.
GetLength
(
I1
);
const
auto
X
=
wei_k_y_x_c_grid_desc
.
GetLength
(
I2
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
Y
*
X
*
C
;
const
auto
GemmN
=
K
;
const
auto
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmK0
=
GemmKPad
/
(
GemmKBatch
*
GemmK1
);
// A: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmktotal_gemmm_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: output tensor
const
auto
out_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
return
make_tuple
(
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// namespace ck
#endif
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
ed068043
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template
<
typename
...
In
,
typename
...
Wei
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
GemmK1Value
>
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
const
auto
N
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I3
);
const
auto
K
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I3
);
const
auto
Hi
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I1
);
const
auto
Wi
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I1
);
const
auto
Wo
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_y_x_c_grid_desc
.
GetLength
(
I1
);
const
auto
X
=
wei_k_y_x_c_grid_desc
.
GetLength
(
I2
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
Y
*
X
*
C
;
const
auto
GemmN
=
K
;
const
auto
GemmK
=
N
*
Ho
*
Wo
;
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// A: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: output tensor
const
auto
out_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// namespace ck
#endif
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
ed068043
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// A: out
// B: in
// C: wei
// GemmM = K
// GemmN = Y * X * C
// GemmKTotal = N * Ho * Wo
template
<
typename
...
In
,
typename
...
Wei
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmK1Value
,
typename
GemmKBatchType
,
typename
GemmKPadType
>
__host__
__device__
constexpr
auto
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad
(
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
GemmK1Value
>
,
GemmKBatchType
GemmKBatch
,
GemmKPadType
GemmKPad
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
const
auto
N
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I3
);
const
auto
K
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I3
);
const
auto
Hi
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I1
);
const
auto
Wi
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I1
);
const
auto
Wo
=
out_n_ho_wo_k_grid_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_y_x_c_grid_desc
.
GetLength
(
I1
);
const
auto
X
=
wei_k_y_x_c_grid_desc
.
GetLength
(
I2
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
K
;
const
auto
GemmN
=
Y
*
X
*
C
;
const
auto
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmK0
=
GemmKPad
/
(
GemmKBatch
*
GemmK1
);
// A: output tensor
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// namespace ck
#endif
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
View file @
ed068043
...
...
@@ -21,8 +21,7 @@ template <typename... In,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk
(
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
ed068043
...
...
@@ -124,7 +124,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!"
);
}
__host__
__device__
static
constexpr
auto
GetC
M0N0M1N1M2M3M4N2
ThreadDescriptor
()
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor
_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
...
...
@@ -136,9 +136,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2
BlockDescriptor
()
__host__
__device__
static
constexpr
auto
GetC
BlockDescriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
()
{
constexpr
auto
c_m0_n0_m1_n1_m2_n2
_block_desc
=
constexpr
auto
c_
block_desc_
m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
...
...
@@ -146,24 +146,24 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2
Descriptor
(
c_m0_n0_m1_n1_m2_n2
_block_desc
);
return
xdlops_gemm
.
MakeC
Descriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
(
c_block_des
c_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
C
MN
GridDesc
>
template
<
typename
CGridDesc
_M_N
>
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2
GridDescriptor
(
const
C
MN
GridDesc
&
c_m_n
_grid_desc
)
MakeC
GridDescriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
(
const
CGridDesc
_M_N
&
c
_grid_desc
_m_n
)
{
const
auto
c_m0_n0_m1_n1_m2_n2
_grid_desc
=
transform_tensor_descriptor
(
c_
m_n_
grid_desc
,
const
auto
c_
grid_desc_
m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc
_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2
Descriptor
(
c_m0_n0_m1_n1_m2_n2
_grid_desc
);
return
xdlops_gemm
.
MakeC
Descriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
(
c_grid_des
c_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
MakeA
K0M0M1M2K1
BlockDescriptor
()
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor
_K0_M0_M1_M2_K1
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
...
...
@@ -175,7 +175,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeB
K0N0N1N2K1
BlockDescriptor
()
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor
_K0_N0_N1_N2_K1
()
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
...
...
@@ -187,8 +187,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
static
constexpr
auto
a_k0_m0_m1_m2_k1
_block_desc
=
MakeAK0M0M1M2K1BlockDescriptor
();
static
constexpr
auto
b_k0_n0_n1_n2_k1
_block_desc
=
MakeBK0N0N1N2K1BlockDescriptor
();
static
constexpr
auto
a_
block_desc_
k0_m0_m1_m2_k1
=
MakeABlockDescriptor_K0_M0_M1_M2_K1
();
static
constexpr
auto
b_
block_desc_
k0_n0_n1_n2_k1
=
MakeBBlockDescriptor_K0_N0_N1_N2_K1
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
...
...
@@ -202,7 +202,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1
_block_desc
,
a_thread_copy_
.
Run
(
a_
block_desc_
k0_m0_m1_m2_k1
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
...
...
@@ -211,7 +211,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_k0_n0_n1_n2_k1
_block_desc
,
b_thread_copy_
.
Run
(
b_
block_desc_
k0_n0_n1_n2_k1
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
...
...
@@ -256,7 +256,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_k0_m0_m1_m2_k1
_block_desc
),
decltype
(
a_
block_desc_
k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
K0
,
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
...
...
@@ -266,7 +266,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_k0_n0_n1_n2_k1
_block_desc
),
decltype
(
b_
block_desc_
k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
K0
,
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
ed068043
...
...
@@ -16,44 +16,46 @@ namespace ck {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AK0MK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CBlockClusterAdaptor
>
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_v2r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AK0MK1GridDesc
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
kernel_gemm_xdlops_v2r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
Block2CTileMap
block_2_ctile_map
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_
k0_m_k1_
grid_desc
,
b_
k0_n_k1_
grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
c_block_cluster_adaptor
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_grid_desc
_k0_m_k1
,
b_grid_desc
_k0_n_k1
,
c_grid_des
c_m0_n0_m1_n1_m2_m3_m4_n2
,
block_2_ctile_map
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
A
K0MK1
GridDesc
,
typename
B
K0NK1
GridDesc
,
typename
CM0N0M1N1M2M3M4N2
GridDesc
,
typename
C
Block
ClusterAdaptor
>
typename
AGridDesc
_K0_M_K1
,
typename
BGridDesc
_K0_N_K1
,
typename
C
GridDesc_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
,
typename
Block
2CTileMap
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -61,34 +63,34 @@ __global__ void
kernel_gemm_xdlops_v2r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_
k0_m_k1_
grid_desc
,
const
void
CONSTANT
*
p_b_
k0_n_k1_
grid_desc
,
const
void
CONSTANT
*
p_c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
const
void
CONSTANT
*
p_
c_
block_
cluster_adaptor
)
const
void
CONSTANT
*
p_a_grid_desc
_k0_m_k1
,
const
void
CONSTANT
*
p_b_grid_desc
_k0_n_k1
,
const
void
CONSTANT
*
p_c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
,
const
void
CONSTANT
*
p_block_
2_ctile_map
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
const
auto
a_
k0_m_k1_
grid_desc
=
*
reinterpret_cast
<
const
A
K0MK1
GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_a_
k0_m_k1_
grid_desc
));
const
auto
b_
k0_n_k1_
grid_desc
=
*
reinterpret_cast
<
const
B
K0NK1
GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_b_
k0_n_k1_
grid_desc
));
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
=
*
reinterpret_cast
<
const
CM0N0M1N1M2M3M4N2
GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
));
const
auto
c_
block_
cluster_adaptor
=
*
reinterpret_cast
<
const
C
Block
ClusterAdaptor
*>
(
cast_pointer_to_generic_address_space
(
p_
c_
block_
cluster_adaptor
));
const
auto
a_grid_desc
_k0_m_k1
=
*
reinterpret_cast
<
const
AGridDesc
_K0_M_K1
*>
(
cast_pointer_to_generic_address_space
(
p_a_grid_desc
_k0_m_k1
));
const
auto
b_grid_desc
_k0_n_k1
=
*
reinterpret_cast
<
const
BGridDesc
_K0_N_K1
*>
(
cast_pointer_to_generic_address_space
(
p_b_grid_desc
_k0_n_k1
));
const
auto
c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
=
*
reinterpret_cast
<
const
C
GridDesc_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
*>
(
cast_pointer_to_generic_address_space
(
p_c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
));
const
auto
block_
2_ctile_map
=
*
reinterpret_cast
<
const
Block
2CTileMap
*>
(
cast_pointer_to_generic_address_space
(
p_block_
2_ctile_map
));
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_
k0_m_k1_
grid_desc
,
b_
k0_n_k1_
grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
c_block_cluster_adaptor
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_grid_desc
_k0_m_k1
,
b_grid_desc
_k0_n_k1
,
c_grid_des
c_m0_n0_m1_n1_m2_m3_m4_n2
,
block_2_ctile_map
);
}
#endif
...
...
@@ -97,12 +99,12 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
A
K0MK1
GridDesc
,
typename
B
K0NK1
GridDesc
,
typename
C
MN
GridDesc
,
typename
AGridDesc
_K0_M_K1
,
typename
BGridDesc
_K0_N_K1
,
typename
CGridDesc
_M_N
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
K
0
PerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
K1Value
,
...
...
@@ -154,50 +156,50 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_
k0_m_k1_
block_desc
=
[
&
]()
{
constexpr
auto
a_block_desc
_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_
k0_n_k1_
block_desc
=
[
&
]()
{
constexpr
auto
b_block_desc
_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_
k0_m_k1_
block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc
_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
math
::
integer_least_multiple
(
b_
k0_n_k1_
block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
b_block_desc
_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
A
K0MK1
GridDesc
&
a_k0_m_k1
_grid_desc
,
const
B
K0NK1
GridDesc
&
b_k0_n_k1
_grid_desc
,
const
C
MN
GridDesc
&
c_m_n
_grid_desc
,
CheckValidity
(
const
AGridDesc
_K0_M_K1
&
a
_grid_desc
_k0_m_k1
,
const
BGridDesc
_K0_N_K1
&
b
_grid_desc
_k0_n_k1
,
const
CGridDesc
_M_N
&
c
_grid_desc
_m_n
,
index_t
M01
,
index_t
N01
)
{
...
...
@@ -208,16 +210,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_
k0_m_k1_
grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_
k0_n_k1_
grid_desc
.
GetLength
(
I1
);
const
auto
K0
=
a_
k0_m_k1_
grid_desc
.
GetLength
(
I0
);
const
auto
M
=
a_grid_desc
_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc
_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc
_k0_m_k1
.
GetLength
(
I0
);
if
(
!
(
M
==
c_
m_n_
grid_desc
.
GetLength
(
I0
)
&&
N
==
c_
m_n_
grid_desc
.
GetLength
(
I1
)
&&
K0
==
b_
k0_n_k1_
grid_desc
.
GetLength
(
I0
)
&&
K1
==
a_
k0_m_k1_
grid_desc
.
GetLength
(
I2
)
&&
K1
==
b_
k0_n_k1_
grid_desc
.
GetLength
(
I2
)))
if
(
!
(
M
==
c_grid_desc
_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc
_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc
_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc
_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc
_k0_n_k1
.
GetLength
(
I2
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
KPerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K
0
PerBlock
==
0
))
return
false
;
// check M01, N01
...
...
@@ -235,48 +237,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
C
MN
GridDesc
&
c_m_n
_grid_desc
)
CalculateGridSize
(
const
CGridDesc
_M_N
&
c
_grid_desc
_m_n
)
{
const
auto
M
=
c_
m_n_
grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_
m_n_
grid_desc
.
GetLength
(
I1
);
const
auto
M
=
c_grid_desc
_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc
_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
(
K0
/
K0PerBlock
)
>
1
;
return
has_main_k0_block_loop
;
}
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2
GridDescriptor
(
const
C
MN
GridDesc
&
c_m_n
_grid_desc
)
MakeC
GridDescriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
(
const
CGridDesc
_M_N
&
c
_grid_desc
_m_n
)
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_
k0_m_k1_
block_desc
=
[
&
]()
{
constexpr
auto
a_block_desc
_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_
k0_n_k1_
block_desc
=
[
&
]()
{
constexpr
auto
b_block_desc
_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
...
...
@@ -284,23 +293,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_
k0_m_k1_
block_desc
),
decltype
(
b_
k0_n_k1_
block_desc
),
decltype
(
a_block_desc
_k0_m_k1
),
decltype
(
b_block_desc
_k0_n_k1
),
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
K1
>
;
return
BlockwiseGemm
::
MakeCM0N0M1N1M2M3M4N2
GridDescriptor
(
c_m_n
_grid_desc
);
return
BlockwiseGemm
::
MakeC
GridDescriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
(
c
_grid_desc
_m_n
);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
Make
C
Block
ClusterAdaptor
(
const
C
MN
GridDesc
&
c_m_n
_grid_desc
,
index_t
M01
,
index_t
N01
)
MakeBlock
2CTileMap
(
const
CGridDesc
_M_N
&
c
_grid_desc
_m_n
,
index_t
M01
,
index_t
N01
)
{
const
auto
M
=
c_
m_n_
grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_
m_n_
grid_desc
.
GetLength
(
I1
);
const
auto
M
=
c_grid_desc
_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc
_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
...
...
@@ -331,30 +340,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{},
1
,
1
));
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_
k0_m_k1_
grid_desc
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc
_k0_m_k1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_grid
,
b_
k0_n_k1_
grid_desc
.
GetElementSpaceSize
());
p_b_grid
,
b_grid_desc
_k0_n_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
.
GetElementSpaceSize
());
p_c_grid
,
c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
const
auto
K0
=
a_
k0_m_k1_
grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc
_k0_m_k1
.
GetLength
(
I0
);
// divide block work by [M, N]
const
auto
block_work_idx
=
c_
block_
cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_
2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
...
...
@@ -367,32 +379,32 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_
k0_m_k1_
block_desc
=
[
&
]()
{
constexpr
auto
a_block_desc
_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_
k0_n_k1_
block_desc
=
[
&
]()
{
constexpr
auto
b_block_desc
_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
...
...
@@ -400,14 +412,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
KPerBlock
,
MPerBlock
,
K1
>
,
Sequence
<
K
0
PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
a_
k0_m_k1_
grid_desc
),
decltype
(
a_
k0_m_k1_
block_desc
),
decltype
(
a_grid_desc
_k0_m_k1
),
decltype
(
a_block_desc
_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
...
...
@@ -417,23 +429,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_
k0_m_k1_
grid_desc
,
true
>
(
a_grid_desc
_k0_m_k1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_
k0_m_k1_
block_desc
,
a_block_desc
_k0_m_k1
,
make_multi_index
(
0
,
0
,
0
));
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
KPerBlock
,
NPerBlock
,
K1
>
,
Sequence
<
K
0
PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
b_
k0_n_k1_
grid_desc
),
decltype
(
b_
k0_n_k1_
block_desc
),
decltype
(
b_grid_desc
_k0_n_k1
),
decltype
(
b_block_desc
_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
...
...
@@ -443,15 +455,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_
k0_n_k1_
grid_desc
,
true
>
(
b_grid_desc
_k0_n_k1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_
k0_n_k1_
block_desc
,
b_block_desc
_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// a_mtx[K
0
PerBlock, MPerBlock] is in LDS
// b_mtx[K
0
PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
...
...
@@ -460,8 +472,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_
k0_m_k1_
block_desc
),
decltype
(
b_
k0_n_k1_
block_desc
),
decltype
(
a_block_desc
_k0_m_k1
),
decltype
(
b_block_desc
_k0_n_k1
),
MPerXDL
,
NPerXDL
,
MRepeat
,
...
...
@@ -472,13 +484,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_
k0_m_k1_
block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc
_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
FloatAB
*
p_a_block
=
p_shared_block
;
FloatAB
*
p_b_block
=
p_shared_block
+
a_block_space_size
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K
0
PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K
0
PerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
AGridStepHacks
{};
...
...
@@ -490,46 +502,51 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hack
=
BGridMoveSliceWindowStepHacks
{};
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_a_block
,
a_
k0_m_k1_
block_desc
.
GetElementSpaceSize
());
p_a_block
,
a_block_desc
_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_b_block
,
b_
k0_n_k1_
block_desc
.
GetElementSpaceSize
());
p_b_block
,
b_block_desc
_k0_n_k1
.
GetElementSpaceSize
());
// preload data into LDS
{
a_blockwise_copy
.
RunRead
(
a_
k0_m_k1_
grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_
k0_n_k1_
grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
_k0_m_k1
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
_k0_n_k1
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_
k0_m_k1_
block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_
k0_n_k1_
block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
_k0_m_k1
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
_k0_n_k1
,
b_block_buf
);
}
// main body
index_t
k_block_data_begin
=
0
;
index_t
k
0
_block_data_begin
=
0
;
do
if
constexpr
(
HasMainKBlockLoop
)
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m_k1_grid_desc
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n_k1_grid_desc
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
do
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
RunRead
(
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
block_sync_lds
();
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_
k0_m_k1_
block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_
k0_n_k1_
block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
_k0_m_k1
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
_k0_n_k1
,
b_block_buf
);
k_block_data_begin
+=
KPerBlock
;
}
while
(
k_block_data_begin
<
(
K0
-
KPerBlock
));
k0_block_data_begin
+=
K0PerBlock
;
}
while
(
k0_block_data_begin
<
(
K0
-
K0PerBlock
));
}
// tail
{
...
...
@@ -540,19 +557,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory
{
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
=
blockwise_gemm
.
GetCM0N0M1N1M2M3M4N2
BlockDescriptor
();
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_m3_m4_n2
_block_desc
.
GetLength
(
I7
);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_desc
=
constexpr
auto
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetC
BlockDescriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
();
constexpr
auto
M0
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_
block_desc_
m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I7
);
constexpr
auto
c_
thread_desc_
m0_n0_m1_n1_m2_m3_m4_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
M0
>
{},
Number
<
N0
>
{},
I1
,
I1
,
Number
<
M2
>
{},
I1
,
Number
<
M4
>
{},
I1
));
...
...
@@ -591,8 +608,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto
c_thread_copy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_desc
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
),
decltype
(
c_
thread_desc_
m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
),
Sequence
<
M0
,
N0
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
...
...
@@ -601,7 +618,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1
,
true
>
{
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
m_thread_data_on_grid_idx
[
I1
],
...
...
@@ -611,10 +628,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
m_thread_data_on_grid_idx
[
I4
],
n_thread_data_on_grid_idx
[
I2
])};
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2
_thread_desc
,
c_thread_copy
.
Run
(
c_
thread_desc_
m0_n0_m1_n1_m2_m3_m4_n2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
c_
grid_desc_
m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
}
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
0 → 100644
View file @
ed068043
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace
ck
{
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CBlockClusterAdaptor
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
ABK0MK1GridDesc
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
b_b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CBlockClusterAdaptor
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_b_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_b_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
void
CONSTANT
*
p_c_block_cluster_adaptor
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
const
auto
a_b_k0_m_k1_grid_desc
=
*
reinterpret_cast
<
const
ABK0MK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_a_b_k0_m_k1_grid_desc
));
const
auto
b_b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BBK0NK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_b_b_k0_n_k1_grid_desc
));
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
*
reinterpret_cast
<
const
CM0N0M1N1M2M3M4N2GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
));
const
auto
c_block_cluster_adaptor
=
*
reinterpret_cast
<
const
CBlockClusterAdaptor
*>
(
cast_pointer_to_generic_address_space
(
p_c_block_cluster_adaptor
));
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
}
#endif
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CMNGridDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGridStepHacks
,
typename
BGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsExtraM
,
bool
BBlockLdsExtraN
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
math
::
integer_least_multiple
(
b_k0_n_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
)
&&
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
);
const
auto
N
=
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
);
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
KBatch
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
);
if
(
!
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
K0
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I1
)
&&
K1
==
a_b_k0_m_k1_grid_desc
.
GetLength
(
I3
)
&&
K1
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I3
)
&&
KBatch
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I0
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
)
*
KBatch
;
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
return
has_main_k0_block_loop
;
}
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
K1
>
;
return
BlockwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_m_n_grid_desc
);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
KBatch
),
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{}));
const
auto
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
KBatch
,
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
);
return
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor
;
}
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{},
1
,
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_grid
,
b_b_k0_n_k1_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
// divide block work by [M, N]
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I2
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
a_b_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
MPerBlock
+
1
>
{}
*
K1
,
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
b_b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
NPerBlock
+
1
>
{}
*
K1
,
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
1
,
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
a_b_k0_m_k1_grid_desc
),
decltype
(
a_b_k0_m_k1_block_desc
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
2
,
1
,
3
>
,
ABlockTransferSrcVectorDim
,
3
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_b_k0_m_k1_grid_desc
,
make_multi_index
(
k_batch_id
,
0
,
m_block_data_idx_on_grid
,
0
),
a_b_k0_m_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
));
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
1
,
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
b_b_k0_n_k1_grid_desc
),
decltype
(
b_b_k0_n_k1_block_desc
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
2
,
1
,
3
>
,
BBlockTransferSrcVectorDim
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_b_k0_n_k1_grid_desc
,
make_multi_index
(
k_batch_id
,
0
,
n_block_data_idx_on_grid
,
0
),
b_b_k0_n_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
K1
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
FloatAB
*
p_a_block
=
p_shared_block
;
FloatAB
*
p_b_block
=
p_shared_block
+
a_block_space_size
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
AGridStepHacks
{};
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
BGridStepHacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hack
=
AGridMoveSliceWindowStepHacks
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hack
=
BGridMoveSliceWindowStepHacks
{};
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_a_block
,
a_k0_m_k1_block_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_b_block
,
b_k0_n_k1_block_desc
.
GetElementSpaceSize
());
// preload data into LDS
{
a_blockwise_copy
.
RunRead
(
a_b_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_b_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_b_k0_n_k1_block_desc
,
b_block_buf
);
}
// main body
index_t
k_block_data_begin
=
0
;
if
constexpr
(
HasMainKBlockLoop
)
{
do
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_b_k0_m_k1_grid_desc
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_b_k0_n_k1_grid_desc
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
RunRead
(
a_b_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_b_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_b_k0_n_k1_block_desc
,
b_block_buf
);
k_block_data_begin
+=
K0PerBlock
;
}
while
(
k_block_data_begin
<
(
K0
-
K0PerBlock
));
}
// tail
{
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
// output: register to global memory
{
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I7
);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
M0
>
{},
Number
<
N0
>
{},
I1
,
I1
,
Number
<
M2
>
{},
I1
,
Number
<
M4
>
{},
I1
));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_grid
=
m_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_grid
=
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
=
CGridStepHacks
{};
const
auto
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_grid_idx
=
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_grid
));
const
auto
n_thread_data_on_grid_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_grid_idx
=
n_thread_data_on_grid_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_grid
));
auto
c_thread_copy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
),
Sequence
<
M0
,
N0
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
m_thread_data_on_grid_idx
[
I1
],
n_thread_data_on_grid_idx
[
I1
],
m_thread_data_on_grid_idx
[
I2
],
m_thread_data_on_grid_idx
[
I3
],
m_thread_data_on_grid_idx
[
I4
],
n_thread_data_on_grid_idx
[
I2
])};
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
}
}
};
// namespace ck
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
ed068043
...
...
@@ -413,7 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v2
"wrong! SrcDesc need to known at compile-time"
);
}
__device__
void
Set
Dst
SliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
__device__
void
Set
Src
SliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
{
src_coord_
=
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
ed068043
...
...
@@ -644,17 +644,17 @@ struct XdlopsGemm
static_assert
(
KPack
%
mfma_instr
.
k_per_blk
==
0
,
"KPack cannot be divided by k_per_blk"
);
}
template
<
typename
CM0N0M1N1M2N2
Desc
>
template
<
typename
C
Desc_
M0
_
N0
_
M1
_
N1
_
M2
_
N2
>
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2
Descriptor
(
const
CM0N0M1N1M2N2
Desc
&
c_m0_n0_m1_n1_m2_n2
_desc
)
MakeC
Descriptor_
M0
_
N0
_
M1
_
N1
_
M2
_
M3
_
M4
_
N2
(
const
C
Desc_
M0
_
N0
_
M1
_
N1
_
M2
_
N2
&
c_des
c_m0_n0_m1_n1_m2_n2
)
{
const
auto
M0
=
c_m0_n0_m1_n1_m2_n2
_desc
.
GetLength
(
I0
);
const
auto
N0
=
c_m0_n0_m1_n1_m2_n2
_desc
.
GetLength
(
I1
);
const
auto
M1
=
c_m0_n0_m1_n1_m2_n2
_desc
.
GetLength
(
I2
);
const
auto
N1
=
c_m0_n0_m1_n1_m2_n2
_desc
.
GetLength
(
I3
);
const
auto
M0
=
c_
desc_
m0_n0_m1_n1_m2_n2
.
GetLength
(
I0
);
const
auto
N0
=
c_
desc_
m0_n0_m1_n1_m2_n2
.
GetLength
(
I1
);
const
auto
M1
=
c_
desc_
m0_n0_m1_n1_m2_n2
.
GetLength
(
I2
);
const
auto
N1
=
c_
desc_
m0_n0_m1_n1_m2_n2
.
GetLength
(
I3
);
return
transform_tensor_descriptor
(
c_m0_n0_m1_n1_m2_n2
_desc
,
c_
desc_
m0_n0_m1_n1_m2_n2
,
make_tuple
(
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
...
...
composable_kernel/include/utility/config.hpp
View file @
ed068043
...
...
@@ -95,7 +95,7 @@
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 1
// merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
0
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
1
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...
...
composable_kernel/include/utility/type.hpp
View file @
ed068043
...
...
@@ -16,6 +16,9 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_same_v
=
is_same
<
X
,
Y
>::
value
;
template
<
typename
T
>
using
remove_reference_t
=
typename
std
::
remove_reference
<
T
>::
type
;
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
View file @
ed068043
...
...
@@ -92,7 +92,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k
,
y
,
x
,
c
));
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
n
,
ho
,
wo
,
k
));
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk
_pad
(
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk
(
in_n_hi_wi_c_desc
,
wei_k_y_x_c_desc
,
out_n_ho_wo_k_desc
,
...
...
@@ -230,14 +230,14 @@ extern "C" __global__ void
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
28
,
28
,
256
));
constexpr
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk
_pad
(
in_n_hi_wi_c_desc
,
wei_k_y_x_c_desc
,
out_n_ho_wo_k_desc
,
make_tuple
(
1
,
1
),
make_tuple
(
1
,
1
),
make_tuple
(
1
,
1
),
make_tuple
(
1
,
1
),
Number
<
K1
>
{});
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk
(
in_n_hi_wi_c_desc
,
wei_k_y_x_c_desc
,
out_n_ho_wo_k_desc
,
make_tuple
(
1
,
1
),
make_tuple
(
1
,
1
),
make_tuple
(
1
,
1
),
make_tuple
(
1
,
1
),
Number
<
K1
>
{});
constexpr
auto
a_k0_m_k1_grid_desc_tmp
=
descs
[
I0
];
constexpr
auto
b_k0_n_k1_grid_desc_tmp
=
descs
[
I1
];
...
...
device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp
0 → 100644
View file @
ed068043
#include <stdlib.h>
#include "config.hpp"
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
NHWC
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
using
KYXC
=
ck
::
tensor_layout
::
convolution
::
KYXC
;
using
NHWK
=
ck
::
tensor_layout
::
convolution
::
NHWK
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk
=
std
::
tuple
<
// clang-format off
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_conv_fwd_instance
<
2
,
F16
,
F16
,
F16
,
NHWC
,
KYXC
,
NHWK
>
(
std
::
vector
<
DeviceConvFwdPtr
>&
device_conv_instances
)
{
using
DeviceConvs
=
device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk
;
const
auto
device_convs
=
DeviceConvs
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceConvs
>
,
1
>
{}([
&
](
auto
i
)
{
using
Conv
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
device_convs
))
>
;
auto
conv
=
Conv
{};
device_conv_instances
.
push_back
(
std
::
make_unique
<
Conv
>
(
conv
));
});
}
}
// namespace device_conv_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp
0 → 100644
View file @
ed068043
#include <stdlib.h>
#include "config.hpp"
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
NHWC
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
using
KYXC
=
ck
::
tensor_layout
::
convolution
::
KYXC
;
using
NHWK
=
ck
::
tensor_layout
::
convolution
::
NHWK
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk
=
std
::
tuple
<
// clang-format off
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_conv_fwd_instance
<
2
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
>
(
std
::
vector
<
DeviceConvFwdPtr
>&
device_conv_instances
)
{
using
DeviceConvs
=
device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk
;
const
auto
device_convs
=
DeviceConvs
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceConvs
>
,
1
>
{}([
&
](
auto
i
)
{
using
Conv
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
device_convs
))
>
;
auto
conv
=
Conv
{};
device_conv_instances
.
push_back
(
std
::
make_unique
<
Conv
>
(
conv
));
});
}
}
// namespace device_conv_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp
0 → 100644
View file @
ed068043
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_km_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Col
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_km_kn_mn
;
const
auto
device_gemms
=
DeviceGemms
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceGemms
>
,
1
>
{}([
&
](
auto
i
)
{
using
Gemm
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
device_gemms
))
>
;
auto
gemm
=
Gemm
{};
device_op_instances
.
push_back
(
std
::
make_unique
<
Gemm
>
(
gemm
));
});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp
0 → 100644
View file @
ed068043
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_km_nk_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Col
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_km_nk_mn
;
const
auto
device_gemms
=
DeviceGemms
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceGemms
>
,
1
>
{}([
&
](
auto
i
)
{
using
Gemm
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
device_gemms
))
>
;
auto
gemm
=
Gemm
{};
device_op_instances
.
push_back
(
std
::
make_unique
<
Gemm
>
(
gemm
));
});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp
0 → 100644
View file @
ed068043
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Row
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn
;
const
auto
device_gemms
=
DeviceGemms
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceGemms
>
,
1
>
{}([
&
](
auto
i
)
{
using
Gemm
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
device_gemms
))
>
;
auto
gemm
=
Gemm
{};
device_op_instances
.
push_back
(
std
::
make_unique
<
Gemm
>
(
gemm
));
});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment