gaoqiong / composable_kernel · Commits

Commit 506a823a
Authored May 30, 2020 by Chao Liu
Parent: 80901f59

    clean up

Showing 20 changed files with 208 additions and 5028 deletions
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp   +0 -457
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp   +0 -218
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp   +0 -162
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp   +0 -119
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp   +0 -656
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp   +0 -337
composable_kernel/include/tensor_operation/xdlops_gemm.hpp   +0 -920
composable_kernel/include/utility/amd_xdlops_emulate.hpp   +0 -217
driver/include/device_col2im_eb_nchw.hpp   +0 -109
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp   +31 -21
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp   +46 -34
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp   +31 -21
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp   +31 -21
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp   +69 -100
driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp   +0 -98
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp   +0 -486
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp   +0 -189
driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp   +0 -374
driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp   +0 -334
driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp   +0 -155
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp   deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          index_t N0,
          index_t N1,
          index_t N2,
          index_t Ho0,
          index_t Ho1,
          index_t Ho2,
          index_t Wo0,
          index_t Wo1,
          index_t Wo2,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
          class InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder,
          class InBlockCopyDstAccessOrder,
          index_t InBlockCopyDataPerAccess_W2,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder,
          class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
        constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N0 * Ho0 * Wo0;

        static_assert(N == N0 * N1 * N2 && Ho == Ho0 * Ho1 * Ho2 && Wo == Wo0 * Wo1 * Wo2,
                      "wrong!");

        static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        //   tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
        constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
            in_n_c_h_w_global_desc.Extract(I0, I2, I3)
                .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
                .Fold(I2, Number<Wo1>{}, Number<Wo2>{})
                .Fold(I1, Number<Ho1>{}, Number<Ho2>{})
                .Fold(I0, Number<N1>{}, Number<N2>{});

        constexpr auto in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc =
            in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
                Sequence<1, 4, 7, 0, 3, 6, 2, 5, 8>{});

        //   batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                .Extract(Sequence<1, 2, 3>{});

        //   merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
            make_ConstantMergedTensorDescriptor(
                in_c_y_x_global_desc.Embed(in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc),
                Sequence<0, 1, 2>{},
                Sequence<3>{},
                Sequence<4>{},
                Sequence<5>{},
                Sequence<6, 7, 8>{},
                Sequence<9>{},
                Sequence<10>{},
                Sequence<11>{});

        //   memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc =
            make_ConstantTensorDescriptor_packed(
                Sequence<EPerBlock, N1, Ho1, Wo1, BPerBlock, N2, Ho2, Wo2>{});

        //   input blockwise copy
        //   slice a merged tensor, reorder and copy to a normal tensor
        //   this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
            BlockSize,
            Float,
            decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
            decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc),
            decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
            InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            InBlockCopyDataPerAccess_W2,
            InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
                                         {0, 0, 0, 0, 0, 0, 0, 0});

        // weight tensor
        //   tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

        //   tensor descriptor in LDS, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        //   operator for blockwise copy of weight into LDS
        //   slice a tensor, and copy it into another tensor
        //   this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
            BlockSize,
            Float,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});

#if 0
        if(get_block_1d_id() == 0)
        {
            printf("id (%d %d), in offset: %d %d, wei offset %d %d\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   blockwise_in_copy.mThreadSrcOffset,
                   blockwise_in_copy.mThreadDstOffset,
                   blockwise_wei_copy.mThreadSrcOffset,
                   blockwise_wei_copy.mThreadDstOffset);
        }
#endif

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[EPerBlock, KPerBlock] is in LDS
        //   b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
        //   c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //   register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB ==
                          0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
            make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));

        // sanity check
        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<GemmMRepeat * GemmMPerThreadSubC>{},
                                                 Number<N1 * Ho1 * Wo1 * N2 * Ho2 * Wo2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc),
            decltype(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space = math::integer_least_multiple(
            in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc, p_out_thread);

        const Float* p_wei_block_on_global = p_wei_global;

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
                                                         p_wei_register_buffer);

#if 0
                if(get_block_1d_id() == 0)
                {
                    printf("tid (%d %d), %f %f %f %f\n",
                           get_block_1d_id(),
                           get_thread_local_1d_id(),
                           p_wei_register_buffer[0],
                           p_wei_register_buffer[1],
                           p_wei_register_buffer[2],
                           p_wei_register_buffer[3]);
                }
#endif

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

            // even iteration
            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
                                                     p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            //   output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, Ho1, Wo1, 1, 1, 1, N2, Ho2, Wo2>{});

            //   output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
                out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<6, 3, 9, 0, 1, 2, 7, 4, 10, 8, 5, 11>{});

            //   output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
                    .Fold(I2, Sequence<Ho1, Ho2>{})
                    .Fold(I1, Sequence<K1, K2>{})
                    .Fold(I0, Sequence<N1, N2>{});

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2);

            //   output merged global tensor descriptor, for calculating origin of thread tensor
            //   in global memory
            constexpr auto out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
                make_ConstantMergedTensorDescriptor(
                    out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5),
                    Sequence<3>{},
                    Sequence<1>{},
                    Sequence<5>{},
                    Sequence<8>{},
                    Sequence<0, 4, 7>{},
                    Sequence<2>{},
                    Sequence<6>{},
                    Sequence<9>{});

            //   origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0);

            threadwise_generic_tensor_slice_copy_v1(
                out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 12, 1>::type{},
                Number<1>{});
        }
    }
};

} // namespace ck
#endif
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp   deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"

namespace ck {

template <index_t GemmKPACK>
struct make_vectorized_WeiDesc_Xdlops
{
    template <typename WeiDesc>
    __device__ constexpr auto get(WeiDesc&)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};

        constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
        constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        /* kpack comes from c*y*x */
        static_assert((C * Y * X) % GemmKPACK == 0,
                      "C needs to be multiple of vectorized GemmKPACK");
        constexpr index_t GemmK = (C * Y * X) / GemmKPACK;

        constexpr auto wei_gemmm_gemmk_global_desc = transform_tensor_descriptor(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3),
            make_tuple(PassThrough<K>{}, PassThrough<C * Y * X>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_global_desc,
            make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_gemmkpack_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return wei_gemmk_gemmm_gemmkpack_global_desc;
    }
};

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmKPACK,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          class GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
          class GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
          index_t GemmABlockCopySrcDataPerRead_GemmKPACK,
          index_t GemmABlockCopyDstDataPerWrite_GemmKPACK,
          class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
          class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t GemmM = K;
        constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
        constexpr index_t GemmN = N * Ho * Wo;

        static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
                          GemmK % GemmKPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        // sanity-check for vectorized memory load
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                          (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto in_gemmk_gemmkpack_gemmn_global_desc = transform_tensor_descriptor(
            in_gemmk_gemmn_global_desc,
            make_tuple(UnMerge<Sequence<GemmK, GemmKPACK>>{}, PassThrough<GemmN>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}));

        constexpr auto in_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
            in_gemmk_gemmkpack_gemmn_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<GemmN>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc =
            make_vectorized_WeiDesc_Xdlops<GemmKPACK>{}.get(wei_k_c_y_x_global_desc);

        constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1<
            GridSize,
            BlockSize,
            Float,
            AccDataType,
            Float,
            decltype(wei_gemmk_gemmm_gemmkpack_global_desc),
            decltype(in_gemmk_gemmn_gemmkpack_global_desc),
            decltype(out_gemmm_gemmn_global_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmThreadGemmDataPerReadM,
            GemmThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            Sequence<0, 1, 2>,
            2,
            GemmABlockCopySrcDataPerRead_GemmKPACK,
            GemmABlockCopyDstDataPerWrite_GemmKPACK,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            Sequence<0, 1, 2>,
            1,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmKPACK,
            InMemoryDataOperation::Set>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp   deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          class GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          class GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmK,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GemmM = K;
        constexpr index_t GemmK = C * Y * X;
        constexpr index_t GemmN = N * Ho * Wo;

        static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
                          GemmK % GemmKPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                          (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // input tensor
        //   global mem
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

        constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlops_v1<
            GridSize,
            BlockSize,
            Float,
            AccDataType,
            decltype(wei_gemmk_gemmm_global_desc),
            decltype(in_gemmk_gemmn_global_desc),
            decltype(out_gemmm_gemmn_global_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmThreadGemmDataPerReadM,
            GemmThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            Sequence<1, 0>,
            Sequence<1, 0>,
            Sequence<0, 1>,
            0,
            GemmABlockCopySrcDataPerRead_GemmK,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            Sequence<0, 1>,
            Sequence<0, 1>,
            Sequence<0, 1>,
            1,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmN,
            InMemoryDataOperation::Set>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp   deleted 100644 → 0
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class Float,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmMWaves,
          index_t GemmNWaves,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    // static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave,
    // GemmDataPerReadA, GemmDataPerReadB>{};

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;

    static constexpr index_t WaveSize = 64;

    __device__ constexpr auto GetOutputLayout() const
    {
        return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .GetOutputLayout();
    }

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
    {
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");

        static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
                      "BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");

        const index_t waveId   = get_thread_local_1d_id() / WaveSize;
        const index_t waveId_m = waveId / GemmNWaves;
        const index_t waveId_n = waveId % GemmNWaves;

        mMyWaveOffsetA = waveId_m * GemmMPerWave;
        mMyWaveOffsetB = waveId_n * GemmNPerWave;
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .template Run<M, N, K>(
                &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
    {
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

        const auto thread_mtx_on_blk =
            XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
                .GetBeginOfThreadBlk(i);

        const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
        const index_t row = waveId / GemmNWaves * GemmMPerWave + thread_mtx_on_blk.row;

        return MatrixIndex{row, col};
    }

    __device__ constexpr auto GetThreadMatrixCDescriptor() const
    {
        const index_t reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
        return make_ConstantMatrixDescriptor_packed(Number<reg_size>{}, Number<1>{});
    }

    __device__ void XdlopsMatrixCSetZero() const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .SetZeroXdlopsRegs(Number<thread_mtx_size>{});
    }

    template <class FloatC>
    __device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
    }
};

} // namespace ck
#endif
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp   deleted 100644 → 0
#ifndef CK_GRIDWISE_GEMM_XDLOPS_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_K_M,
          class ABlockCopyThreadClusterLengths_K_M,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          class BBlockCopyThreadSliceLengths_K_N,
          class BBlockCopyThreadClusterLengths_K_N,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseGemmTransposedANormalBNormalCXdlops_v1
{
    __device__ void Run(const Float* const __restrict__ p_a_global,
                        const Float* const __restrict__ p_b_global,
                        Float* const __restrict__ p_c_global) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_k_m_global_desc = AGlobalDesc{};
        constexpr auto b_k_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc = CGlobalDesc{};

        constexpr auto K = b_k_n_global_desc.GetLengths()[0];
        constexpr auto N = b_k_n_global_desc.GetLengths()[1];
        constexpr auto M = a_k_m_global_desc.GetLengths()[1];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0,
                      "wrong! M/NPerBlock % M/NPerWave != 0");

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_N,
                                                ABlockCopyDstDataPerWrite_M,
                                                GemmDataPerReadM,
                                                GemmDataPerReadN);

        // LDS
        //   be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_align>{});

        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k_m_global_desc),
                                               decltype(a_k_m_block_desc),
                                               decltype(a_k_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K_M,
                                               ABlockCopyThreadClusterLengths_K_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               ABlockCopyDstAccessOrder,
                                               ABlockCopySrcVectorReadDim,
                                               1,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {0, m_block_data_on_global}, {0, 0});

        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_align>{});

        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k_n_global_desc),
                                               decltype(b_k_n_block_desc),
                                               decltype(b_k_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K_N,
                                               BBlockCopyThreadClusterLengths_K_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               BBlockCopyDstAccessOrder,
                                               BBlockCopySrcVectorReadDim,
                                               1,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {0, n_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, MPerBlock] is in LDS
        //   b_mtx[KPerBlock, NPerBlock] is in LDS
        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //   register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
                BlockSize,
                decltype(a_k_m_block_mtx_desc),
                decltype(b_k_n_block_mtx_desc),
                Float,
                MPerWave,
                NPerWave,
                MWaves,
                NWaves,
                GemmDataPerReadM,
                GemmDataPerReadN>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_align);

        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_a_block_double[2 * a_block_space];
        __shared__ Float p_b_block_double[2 * b_block_space];

        // register allocation for output
        AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);

        blockwise_gemm.XdlopsMatrixCSetZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using b_blockwise_copy_src_step = Sequence<KPerBlock, 0>;
        using a_blockwise_copy_src_step = Sequence<KPerBlock, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                Float* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                Float* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                Float* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
            }
        }

        // load data from xdlops acc regs
        blockwise_gemm.XdlopsMatrixCRead(p_c_thread);

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            // M1 = num_groups;
            // M0 = group_size;
            // N1 = num_blks_per_wave;
            // N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();

            constexpr index_t M0 = CLayout.M1();
            constexpr index_t M1 = CLayout.N1();
            constexpr index_t M2 = CLayout.M0();

            constexpr auto c_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));

            // src descriptor
            constexpr auto c_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<M0, 1, M2, 1>;

            constexpr index_t BlkSize = CLayout.GetBlkSize();
            constexpr index_t NumBlks = CLayout.GetNumBlks();

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;

                const index_t n_thread_data_on_global =
                    n_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_m2_n_thread_desc),
                                                      decltype(c_m0_m1_m2_n_global_desc),
                                                      CThreadCopySliceLengths,
                                                      arithmetic_sequence_gen<0, 4, 1>::type,
                                                      3,
                                                      1,
                                                      1,
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      CGlobalMemoryDataOperation>(
                    {0, 0, 0, 0},
                    {m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    .Run(p_c_thread + i * BlkSize, p_c_global);
            }
        }
    }
};

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_G_K_M,
          class ABlockCopyThreadClusterLengths_G_K_M,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          class BBlockCopyThreadSliceLengths_G_K_N,
          class BBlockCopyThreadClusterLengths_G_K_N,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlops_v1
{
    __device__ void Run(const Float* const __restrict__ p_a_global,
                        const Float* const __restrict__ p_b_global,
                        Float* const __restrict__ p_c_global) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_g_k_m_global_desc = AGlobalDesc{};
        constexpr auto b_g_k_n_global_desc = BGlobalDesc{};
        constexpr auto c_g_m_n_global_desc = CGlobalDesc{};

        constexpr auto G = b_g_k_n_global_desc.GetLengths()[0];
        constexpr auto K = b_g_k_n_global_desc.GetLengths()[1];
        constexpr auto N = b_g_k_n_global_desc.GetLengths()[2];
        constexpr auto M = a_g_k_m_global_desc.GetLengths()[2];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0,
                      "wrong! M/NPerBlock % M/NPerWave != 0");

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<G, MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t group_id                = block_work_id[0];
        const index_t m_block_data_on_global  = block_work_id[1] * MPerBlock;
        const index_t n_block_data_on_global  = block_work_id[2] * NPerBlock;

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_N,
                                                ABlockCopyDstDataPerWrite_M,
                                                GemmDataPerReadM,
                                                GemmDataPerReadN);

        // LDS
        //   be careful of LDS alignment
        constexpr auto a_g_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock>{}, Number<max_align>{});

        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_g_k_m_global_desc),
                                               decltype(a_g_k_m_block_desc),
                                               decltype(a_g_k_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_G_K_M,
                                               ABlockCopyThreadClusterLengths_G_K_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               ABlockCopyDstAccessOrder,
                                               ABlockCopySrcVectorReadDim,
                                               2,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {group_id, 0, m_block_data_on_global}, {0, 0, 0});

        constexpr auto b_g_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, NPerBlock>{}, Number<max_align>{});

        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_g_k_n_global_desc),
                                               decltype(b_g_k_n_block_desc),
                                               decltype(b_g_k_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_G_K_N,
                                               BBlockCopyThreadClusterLengths_G_K_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               BBlockCopyDstAccessOrder,
                                               BBlockCopySrcVectorReadDim,
                                               2,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {group_id, 0, n_block_data_on_global}, {0, 0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, MPerBlock] is in LDS
        //   b_mtx[KPerBlock, NPerBlock] is in LDS
        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //   register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
            a_g_k_m_block_desc.GetLength(I1), a_g_k_m_block_desc.GetLength(I2));

        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
            b_g_k_n_block_desc.GetLength(I1), b_g_k_n_block_desc.GetLength(I2));

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
                BlockSize,
                decltype(a_k_m_block_mtx_desc),
                decltype(b_k_n_block_mtx_desc),
                Float,
                MPerWave,
                NPerWave,
                MWaves,
                NWaves,
                GemmDataPerReadM,
                GemmDataPerReadN>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_g_k_m_block_desc.GetElementSpace(), max_align);

        constexpr index_t b_block_space =
            math::integer_least_multiple(b_g_k_n_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_a_block_double[2 * a_block_space];
        __shared__ Float p_b_block_double[2 * b_block_space];

        // register allocation for output
        AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);

        blockwise_gemm.XdlopsMatrixCSetZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using b_blockwise_copy_src_step = Sequence<0, KPerBlock, 0>;
        using a_blockwise_copy_src_step = Sequence<0, KPerBlock, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                Float* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                Float* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                Float* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
            }
        }

        // load data from xdlops acc regs
        blockwise_gemm.XdlopsMatrixCRead(p_c_thread);

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            // M1 = num_groups;
            // M0 = group_size;
            // N1 = num_blks_per_wave;
            // N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();

            constexpr index_t M0 = CLayout.M1();
            constexpr index_t M1 = CLayout.N1();
            constexpr index_t M2 = CLayout.M0();

            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_g_m_n_global_desc,
                make_tuple(PassThrough<G>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));

            // src descriptor
            constexpr auto c_g_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;

            constexpr index_t BlkSize = CLayout.GetBlkSize();
            constexpr index_t NumBlks = CLayout.GetNumBlks();

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;

                const index_t n_thread_data_on_global =
                    n_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc
),
decltype
(
c_g_m0_m1_m2_n_global_desc
),
CThreadCopySliceLengths
,
arithmetic_sequence_gen
<
0
,
5
,
1
>::
type
,
4
,
1
,
1
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Global
,
CGlobalMemoryDataOperation
>
(
{
0
,
0
,
0
,
0
,
0
},
{
group_id
,
m_thread_data_on_global
/
(
M2
*
M1
),
m_thread_data_on_global
%
(
M2
*
M1
)
/
M2
,
m_thread_data_on_global
%
M2
,
n_thread_data_on_global
})
.
Run
(
p_c_thread
+
i
*
BlkSize
,
p_c_global
);
}
}
}
};
}
// namespace ck
#endif
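A minimal host-side sketch of the LDS double-buffer schedule the deleted kernel above implements (this is an illustration, not part of the kernel; K and KPerBlock values are example assumptions). The main body consumes one K-tile from the "now" half of LDS while prefetching the next tile into the "next" half, and the two halves alternate every inner iteration exactly as the even_loop selection above does.

// lds_double_buffer_schedule.cpp - standalone illustration, compile with any C++ compiler
#include <cstdio>

int main()
{
    const int K         = 64; // total K length (example assumption)
    const int KPerBlock = 8;  // K-tile consumed per blockwise GEMM step

    int buf_now = 0; // which LDS half is being multiplied this iteration
    for(int k_begin = 0; k_begin + 2 * KPerBlock < K; k_begin += 2 * KPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            const int buf_next = 1 - buf_now; // the half being filled by the prefetch
            std::printf("GEMM on K offset %2d from buffer %d, prefetch K offset %2d into buffer %d\n",
                        k_begin + iloop * KPerBlock, buf_now,
                        k_begin + (iloop + 1) * KPerBlock, buf_next);
            buf_now = buf_next;
        }
    }
    // tail: two tiles remain when K divides 2*KPerBlock, otherwise one
    std::printf("tail handles the last %d tile(s)\n", (K % (2 * KPerBlock) == 0) ? 2 : 1);
    return 0;
}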
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp deleted 100644 → 0
#ifndef CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation OutputMemOp>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto c_m_n_global_desc       = CGlobalDesc{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto K     = b_k_n_kpack_global_desc.GetLengths()[0];
        constexpr auto N     = b_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto M     = a_k_m_kpack_global_desc.GetLengths()[1];
        constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_desc = make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_id[0] * MPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * NPerBlock;

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
                                                ABlockCopyDstDataPerWrite_KPACK,
                                                KPACK * GemmDataPerReadM,
                                                KPACK * GemmDataPerReadN);

        // LDS
        // be careful of LDS alignment
        constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(a_k_m_kpack_global_desc),
            decltype(a_k_m_kpack_block_desc),
            decltype(a_k_m_kpack_block_desc.GetLengths()),
            ABlockCopyThreadSliceLengths_K_M_KPACK,
            ABlockCopyThreadClusterLengths_K_M_KPACK,
            ABlockCopyThreadClusterArrangeOrder,
            ABlockCopySrcAccessOrder,
            ABlockCopyDstAccessOrder,
            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
            2,                          // Dst dim to be written in vector form (KPACK dimension)
            ABlockCopySrcDataPerRead,
            ABlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Generic,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({0, k_block_data_on_global, 0}, {0, 0, 0});

        constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(b_k_n_kpack_global_desc),
            decltype(b_k_n_kpack_block_desc),
            decltype(b_k_n_kpack_block_desc.GetLengths()),
            BBlockCopyThreadSliceLengths_K_N_KPACK,
            BBlockCopyThreadClusterLengths_K_N_KPACK,
            BBlockCopyThreadClusterArrangeOrder,
            BBlockCopySrcAccessOrder,
            BBlockCopyDstAccessOrder,
            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
            2,                          // Dst dim to be written in vector form (KPACK dimension)
            BBlockCopySrcDataPerRead,
            BBlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Generic,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({0, b_block_data_on_global, 0}, {0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, MPerBlock] is in LDS
        //   b_mtx[KPerBlocl, NPerBlock] is in LDS
        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //   register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            ABFloat,
            MPerWave,
            NPerWave,
            MWaves,
            NWaves,
            GemmDataPerReadM,
            GemmDataPerReadN>{};

        constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block_double[2 * a_block_space];
        __shared__ ABFloat p_b_block_double[2 * b_block_space];

        // register allocation for output
        AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
        blockwise_gemm.XdlopsMatrixCSetZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
        using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                ABFloat* p_a_block_now = even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                ABFloat* p_b_block_now = even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                ABFloat* p_a_block_next = even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                ABFloat* p_b_block_next = even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                // Vectorize the pointer to match with how half/bfloat16 datatypes are
                // processed in gemm operation. Half type packs 4 half values while
                // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single half to 4 packed half/2 packed bfloat16
                // respectively.
                auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
                auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
                auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
                p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
                auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
            }
        }

        // load data from xldop_acc_regs
        blockwise_gemm.XdlopsMatrixCRead(p_c_thread);

        // copy output: register to global memory
        {
            constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t K0        = OutputLayout.M1();
            constexpr index_t K1        = OutputLayout.N1();
            constexpr index_t K2        = OutputLayout.M0();

            constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});

            using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;

            constexpr index_t BlkSize = OutputLayout.GetBlkSize();
            constexpr index_t NumBlks = OutputLayout.GetNumBlks();

            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,
                    1,
                    1,
                    AddressSpace::Vgpr,
                    is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
                    OutputMemOp>({0, 0, 0, 0},
                                 {k_thread_data_on_global / (K2 * K1),
                                  k_thread_data_on_global % (K2 * K1) / K2,
                                  k_thread_data_on_global % K2,
                                  b_thread_data_on_global})
                    .Run(p_c_thread + i * BlkSize, p_c_global);
            }
        }
    }
};
}
#endif
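A small host-side sketch of the KPACK packing idea this fp16/bfp16 gridwise GEMM relies on (illustration only; the struct standing in for half4_t and all sizes are assumptions). Storing 4 consecutive K elements per (k, m) position contiguously lets the GEMM read them as one packed vector, which is what the reinterpret_cast to half4_t above does on the LDS tile.

// kpack_layout_sketch.cpp - standalone illustration
#include <array>
#include <cstdio>
#include <cstring>

struct half4_emulated { float x[4]; }; // stand-in for the 4-wide half vector type (assumption)

int main()
{
    constexpr int K = 2, M = 3, KPACK = 4;
    std::array<float, K * M * KPACK> a_k_m_kpack{};

    // fill a[k][m][kpack] with a recognizable pattern
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int p = 0; p < KPACK; ++p)
                a_k_m_kpack[(k * M + m) * KPACK + p] = 100.f * k + 10.f * m + p;

    // one packed read per (k, m), the way the xdlops GEMM consumes half4_t values
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
        {
            half4_emulated v;
            std::memcpy(&v, &a_k_m_kpack[(k * M + m) * KPACK], sizeof(v));
            std::printf("k=%d m=%d -> {%g, %g, %g, %g}\n", k, m, v.x[0], v.x[1], v.x[2], v.x[3]);
        }
    return 0;
}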
composable_kernel/include/tensor_operation/xdlops_gemm.hpp deleted 100644 → 0
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"

#define CK_USE_AMD_XDLOPS_EMULATE 1

namespace ck {

enum struct mfma_instr
{
    // fp32
    mfma_f32_32x32x1xf32 = 0,
    mfma_f32_16x16x1xf32,
    mfma_f32_4x4x1xf32,
    mfma_f32_32x32x2xf32, // k reduction
    mfma_f32_16x16x4xf32, // k reduction
    // fp16
    mfma_f32_32x32x4f16,
    mfma_f32_16x16x4f16,
    mfma_f32_4x4x4f16,
    mfma_f32_32x32x8f16,  // k reduction
    mfma_f32_16x16x16f16, // k reduction
    // bfp16
    mfma_f32_32x32x2bf16,
    mfma_f32_16x16x2bf16,
    mfma_f32_4x4x2bf16,
    mfma_f32_32x32x4bf16, // k reduction
    mfma_f32_16x16x8bf16, // k reduction
};

template <mfma_instr instr>
struct mfma_info;

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
    static constexpr index_t group_size = 4, num_groups_blk = 4, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 32, n = 32, k = 1, cycles = 64;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
    {
        static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
                      (MPerWave == 64 && NPerWave == 32), "unsupported xdlops gemm");
        const auto reg_a = *a;
        const auto reg_b = *b;
        gcnasm_mfma_f32_32x32x1f32<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float32_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
    static constexpr index_t group_size = 4, num_groups_blk = 4, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 32, n = 32, k = 2, cycles = 64;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
    {
        static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
        const auto reg_a = *a;
        const auto reg_b = *b;
        gcnasm_mfma_f32_32x32x2f32(reg_a, reg_b, reinterpret_cast<float16_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 16, n = 16, k = 4, cycles = 32;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
    {
        static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
        const auto reg_a = *a;
        const auto reg_b = *b;
        gcnasm_mfma_f32_16x16x4f32(reg_a, reg_b, reinterpret_cast<float4_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 4, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 16, n = 16, k = 1, cycles = 32;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
    {
        static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
                      "unsupported xdlops gemm");
        const auto reg_a = *a;
        const auto reg_b = *b;
        gcnasm_mfma_f32_16x16x1f32<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float16_t*>(reg_c));
    }
};

// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 64, wave_size = 64, num_input_blks = 1;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = 4;
    static constexpr index_t m = 4, n = 64, k = 1, cycles = 8;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
    {
        static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64, "unsupported xdlops gemm");
        const auto reg_a = *a;
        const auto reg_b = *b;
        gcnasm_mfma_f32_4x4x1f32<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float4_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 4, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 32, n = 32, k = 4, cycles = 64;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
    {
        static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
                      (MPerWave == 64 && NPerWave == 32), "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
        const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
        gcnasm_mfma_f32_32x32x4f16<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float32_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 4, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 32, n = 32, k = 8, cycles = 64;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
    {
        static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
        const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
        gcnasm_mfma_f32_32x32x8f16(reg_a, reg_b, reinterpret_cast<float16_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 16, n = 16, k = 16, cycles = 32;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
    {
        static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
        const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
        gcnasm_mfma_f32_16x16x16f16(reg_a, reg_b, reinterpret_cast<float4_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 4, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 16, n = 16, k = 4, cycles = 32;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
    {
        static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
                      "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
        const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
        gcnasm_mfma_f32_16x16x4f16<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float16_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 64, wave_size = 64, num_input_blks = 1;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = 4;
    static constexpr index_t m = 4, n = 64, k = 4, cycles = 8;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
    {
        static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64, "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
        const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
        gcnasm_mfma_f32_4x4x4f16<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float4_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 4, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 32, n = 32, k = 2, cycles = 64;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
    {
        static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
                      (MPerWave == 64 && NPerWave == 32), "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
        const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
        gcnasm_mfma_f32_32x32x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float32_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 4, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 32, n = 32, k = 4, cycles = 64;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
    {
        static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
        const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
        gcnasm_mfma_f32_32x32x4bf16(reg_a, reg_b, reinterpret_cast<float16_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 16, n = 16, k = 8, cycles = 32;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
    {
        static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
        const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
        gcnasm_mfma_f32_16x16x8bf16(reg_a, reg_b, reinterpret_cast<float4_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16, wave_size = 64, num_input_blks = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 4, num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m = 16, n = 16, k = 2, cycles = 32;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
    {
        static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
                      "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
        const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
        gcnasm_mfma_f32_16x16x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float16_t*>(reg_c));
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
    static constexpr index_t group_size = 4, num_groups_blk = 1, num_regs_blk = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 64, wave_size = 64, num_input_blks = 1;
    static constexpr index_t num_output_blks = 1, num_regs_xdlops = 4;
    static constexpr index_t m = 4, n = 64, k = 2, cycles = 8;

    template <index_t MPerWave, index_t NPerWave>
    __device__ void run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
    {
        static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64, "unsupported xdlops gemm");
        const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
        const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
        gcnasm_mfma_f32_4x4x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reinterpret_cast<float4_t*>(reg_c));
    }
};

template <class data_type, index_t MPerWave, index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();

template <> __device__ constexpr auto GetMFMAInfo<float, 32, 64>() { return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 64, 64>() { return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 64, 32>() { return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 32, 32>() { return mfma_info<mfma_instr::mfma_f32_32x32x2xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 16, 16>() { return mfma_info<mfma_instr::mfma_f32_16x16x4xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 16, 64>() { return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 64, 16>() { return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 8, 64>() { return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<float, 4, 64>() { return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 64, 64>() { return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 64, 32>() { return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 32, 64>() { return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 32, 32>() { return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 16, 16>() { return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 16, 64>() { return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 64, 16>() { return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 4, 64>() { return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<half_t, 8, 64>() { return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 64, 64>() { return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 64, 32>() { return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 32, 64>() { return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 32, 32>() { return mfma_info<mfma_instr::mfma_f32_32x32x4bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 16, 16>() { return mfma_info<mfma_instr::mfma_f32_16x16x8bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 16, 64>() { return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 64, 16>() { return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 4, 64>() { return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{}; }
template <> __device__ constexpr auto GetMFMAInfo<ushort, 8, 64>() { return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{}; }

template <class data_type, index_t MPerWave, index_t NPerWave, index_t GemmDataPerReadA, index_t GemmDataPerReadB>
struct XdlopsGemm_t
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    template <index_t M1_, index_t M0_, index_t N1_, index_t N0_>
    struct OutputLayout
    {
        __device__ static constexpr index_t M1() { return M1_; }
        __device__ static constexpr index_t M0() { return M0_; }
        __device__ static constexpr index_t N1() { return N1_; }
        __device__ static constexpr index_t N0() { return N0_; }

        __device__ static constexpr index_t GetBlkSize()
        {
            return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk;
        }

        __device__ static constexpr index_t GetNumBlks()
        {
            constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
            return MPerWave * NPerWave / (mfma_type.m * mfma_type.n);
        }
    };

    __device__ constexpr XdlopsGemm_t()
    {
        static_assert(NPerWave == 4 || NPerWave == 8 || NPerWave == 16 || NPerWave == 32 || NPerWave == 64,
                      "Only support GemmNPerWave == 4, 8, 16, 32 or 64 for xdlops");
        static_assert(MPerWave == 4 || MPerWave == 8 || MPerWave == 16 || MPerWave == 32 || MPerWave == 64,
                      "Only support GemmMPerWave == 4, 8, 16, 32 or 64 for xdlops");
        static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");

        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();

        static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
        static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
                      "m != num_input_blks * num_regs_blk");
        static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || mfma_type.num_output_blks == 1,
                      "incorrect num_output_blks");
        static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
                      "num_regs_blk incorrect");
    }

    __device__ static constexpr bool IsABroadcast() { return NPerWave >= MPerWave; }

    __device__ static constexpr bool IsKReduction()
    {
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
        return mfma_type.num_output_blks == 1 && mfma_type.num_input_blks != 1;
    }

#if CK_USE_AMD_XDLOPS_EMULATE
    // emulate xdlops
    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
    __device__ void XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
                                  const FloatB* const __restrict__ p_b_wave,
                                  FloatC* const __restrict__ p_c_thread) const
    {
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();

        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
        const index_t blk_id = laneId / mfma_type.num_threads_blk;
        const index_t blk_td = laneId % mfma_type.num_threads_blk;

        // K reduction
        static_if<IsKReduction()>{}([&](auto) {
            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
            {
                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                {
                    index_t a_off = (k + n) * M;
                    index_t b_off = (k + n) * N;
                    index_t c_off = 0;

                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                    {
                        index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
                                         m / mfma_type.group_size * (mfma_type.group_size * mfma_type.num_input_blks);
                        index_t bindex = blk_td;
                        p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
                            p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
                    }
                }
            }
        }).Else([&](auto) {
            static_if<IsABroadcast()>{}([&](auto) {
                // ABroadcast
                for(index_t k = 0; k < K; ++k)
                {
                    for(index_t b = 0; b < MPerWave / mfma_type.m; ++b)
                    {
                        for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                        {
                            index_t a_off = k * M + b * mfma_type.m;
                            index_t b_off = k * N + n * mfma_type.num_threads_blk;
                            index_t c_off = n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;

                            for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                            {
                                index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
                                                 m / mfma_type.group_size *
                                                     (mfma_type.group_size * mfma_type.num_input_blks);
                                index_t bindex = blk_td;
                                p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
                                    p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
                            }
                        }
                    }
                }
            }).Else([&](auto) {
                // BBroadcast
                for(index_t k = 0; k < K; ++k)
                {
                    for(index_t b = 0; b < NPerWave / mfma_type.n; ++b)
                    {
                        for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                        {
                            index_t a_off = k * M + n * mfma_type.m;
                            index_t b_off = k * N + b * mfma_type.n;
                            index_t c_off = n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;

                            for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                            {
                                index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
                                                 m / mfma_type.group_size *
                                                     (mfma_type.group_size * mfma_type.num_input_blks);
                                index_t bindex = blk_td;
                                p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
                                    p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
                            }
                        }
                    }
                }
            });
        });
    }
#endif

    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* const __restrict__ p_a_wave,
                        const FloatB* const __restrict__ p_b_wave,
                        FloatC* const __restrict__ p_c_thread) const
    {
        static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
        static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");
        static_assert(is_same<FloatC, float>::value, "FloatC != float");

#if CK_USE_AMD_XDLOPS_EMULATE
        XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
#else
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();

        static_if<!IsKReduction()>{}([&](auto) {
            const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
            for(index_t k = 0; k < K; ++k)
            {
                mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
            }
        }).Else([&](auto) {
            const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
            {
                mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
            }
        });
#endif
    }

    __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
    {
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();

        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
        const index_t blk_id = laneId / mfma_type.num_threads_blk;
        const index_t blk_td = laneId % mfma_type.num_threads_blk;

        index_t col_blk = i % mfma_type.num_output_blks;
        index_t row_blk = i / mfma_type.num_output_blks;
        index_t col     = col_blk * mfma_type.n + blk_td;
        index_t row     = row_blk * mfma_type.m + blk_id * mfma_type.group_size;

        static_if<!IsABroadcast()>{}([&](auto) {
            col_blk = i / mfma_type.num_output_blks;
            row_blk = i % mfma_type.num_output_blks;
            col     = col_blk * mfma_type.n + blk_td;
            row     = row_blk * mfma_type.m + blk_id * mfma_type.group_size;
        });

        return MatrixIndex{row, col};
    }

    __device__ static constexpr auto GetOutputLayout()
    {
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();

        constexpr auto M1 = mfma_type.num_groups_blk;
        constexpr auto M0 = mfma_type.group_size;
        constexpr auto N1 = mfma_type.num_input_blks;
        constexpr auto N0 = mfma_type.num_threads_blk;

        return OutputLayout<M1, M0, N1, N0>{};
    }

    template <index_t Size>
    __device__ void SetZeroXdlopsRegs(Number<Size>) const
    {
#if !CK_USE_AMD_XDLOPS_EMULATE
        // gcnasm_accvgpr_zero<Size>();
#endif
    }

    template <index_t Size, class FloatC>
    __device__ void ReadXdlopsRegs(Number<Size>, FloatC* const __restrict__ p_c_thread) const
    {
#if !CK_USE_AMD_XDLOPS_EMULATE
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
        // gcnasm_nop<mfma_type.cycles>();
        // gcnasm_accvgpr_read<Size>(p_c_thread);
#else
        (void)p_c_thread;
#endif
    }
};
} // namespace ck
#endif
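A standalone illustration of the per-lane accumulator layout implied by the mfma_f32_32x32x1xf32 parameters above (group_size = 4, num_input_blks = 2, num_threads_blk = 32, num_regs_blk = 16). The row index is built exactly like aindex in XdlopsEmulate; the lane id used here is an arbitrary example.

// xdlops_lane_layout_sketch.cpp - standalone illustration
#include <cstdio>

int main()
{
    constexpr int group_size = 4, num_input_blks = 2, num_threads_blk = 32, num_regs_blk = 16;

    const int lane   = 37;                     // any lane id in [0, 64)
    const int blk_id = lane / num_threads_blk; // which input block the lane belongs to
    const int blk_td = lane % num_threads_blk; // position inside that block

    for(int m = 0; m < num_regs_blk; ++m)
    {
        const int row = m % group_size + blk_id * group_size +
                        m / group_size * (group_size * num_input_blks);
        const int col = blk_td;
        std::printf("lane %d, acc reg %2d -> C[%2d][%2d]\n", lane, m, row, col);
    }
    return 0;
}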
composable_kernel/include/utility/amd_xdlops_emulate.hpp deleted 100644 → 0
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

namespace ck {

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);

template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 32; i++)
    {
        reg_c_[i + 32] = reg_c_[i] = reg_c_[i] + reg_a * reg_b;
    }
}

template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
    {
        reg_c_[i] += reg_a * reg_b;
    }
}

template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
    {
        reg_c_[i] += reg_a * reg_b;
    }
}

__device__ void gcnasm_mfma_f32_32x32x2f32(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
    {
        reg_c_[i] += reg_a * reg_b;
    }
}

__device__ void gcnasm_mfma_f32_16x16x4f32(const float& reg_a, const float& reg_b, float4_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);
template <> __device__ void gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);
template <> __device__ void gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&, const half4_t&, float32_t*);
template <> __device__ void gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c) {}

__device__ void gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c) {}
__device__ void gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);
template <> __device__ void gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);
template <> __device__ void gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);
template <> __device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c) {}

__device__ void gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c) {}
__device__ void gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);
template <> __device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c) {}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);
template <> __device__ void gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c) {}
template <> __device__ void gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c) {}
// clang-format on
}
#endif
driver/include/device_col2im_eb_nchw.hpp deleted 100644 → 0
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_col2im_eb_nchw.hpp"

template <typename T,
          typename ColDesc,
          typename ImgDesc,
          typename FilterSizes,
          typename OutputSizes,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
void device_col2im_eb_nchw(ColDesc,
                           const Tensor<T>& col_eb,
                           ImgDesc,
                           Tensor<T>& img_nchw,
                           FilterSizes,
                           OutputSizes,
                           ConvStrides,
                           ConvDilations,
                           LeftPads,
                           RightPads,
                           std::size_t nrepeat)
{
    using namespace ck;

    constexpr auto col_eb_desc   = ColDesc{};
    constexpr auto img_nchw_desc = ImgDesc{};

    constexpr index_t N  = img_nchw_desc.GetLengths()[0];
    constexpr index_t C  = img_nchw_desc.GetLengths()[1];
    constexpr index_t Hi = img_nchw_desc.GetLengths()[2];
    constexpr index_t Wi = img_nchw_desc.GetLengths()[3];

    constexpr index_t E = col_eb_desc.GetLengths()[0];
    constexpr index_t B = col_eb_desc.GetLengths()[1];

    std::size_t data_sz = sizeof(T);
    DeviceMem col_eb_device_buf(data_sz * col_eb.mDesc.GetElementSpace());
    DeviceMem img_nchw_device_buf(data_sz * img_nchw.mDesc.GetElementSpace());

    col_eb_device_buf.ToDevice(col_eb.mData.data());
    img_nchw_device_buf.ToDevice(img_nchw.mData.data());

#if 1
    constexpr index_t BlockSize = 256;

    constexpr index_t EPerBlock = 128;
    constexpr index_t BPerBlock = 128;

    using BlockCopySubLengths_E_B     = Sequence<8, 8>;
    using BlockCopyClusterLengths_E_B = Sequence<16, 16>;

    using BlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
    using BlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
    using BlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]

    constexpr index_t BlockCopyDataPerAccess_B = 1;
#endif

    constexpr index_t GridSize =
        ((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_col2im = GridwiseCol2Im_eb_nchw<GridSize,
                                                            BlockSize,
                                                            T,
                                                            ColDesc,
                                                            ImgDesc,
                                                            FilterSizes,
                                                            OutputSizes,
                                                            ConvStrides,
                                                            ConvDilations,
                                                            LeftPads,
                                                            RightPads,
                                                            EPerBlock,
                                                            BPerBlock,
                                                            BlockCopySubLengths_E_B,
                                                            BlockCopyClusterLengths_E_B,
                                                            BlockCopyThreadClusterArrangeOrder,
                                                            BlockCopySrcAccessOrder,
                                                            BlockCopyDstAccessOrder,
                                                            BlockCopyDataPerAccess_B>{};

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_and_time_kernel(
            run_gridwise_operation<decltype(gridwise_col2im),
                                   const T* const __restrict__,
                                   T* const __restrict__>,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            gridwise_col2im,
            const_cast<const T* const __restrict__>(static_cast<T*>(col_eb_device_buf.GetDeviceBuffer())),
            const_cast<T* const __restrict__>(static_cast<T*>(img_nchw_device_buf.GetDeviceBuffer())));

        printf("Elapsed time : %f ms \n", time);
        usleep(std::min(time * 1000, float(10000)));
    }

    img_nchw_device_buf.FromDevice(img_nchw.mData.data());
}
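A quick standalone check of the grid-size arithmetic used by the deleted driver above: one workgroup per EPerBlock x BPerBlock tile, with both divisions rounded up. The E and B values here are example assumptions.

// grid_size_check.cpp - standalone illustration
#include <cstdio>

int main()
{
    const int E = 1000, B = 500;              // example tensor sizes (assumptions)
    const int EPerBlock = 128, BPerBlock = 128;

    // ceil(E / EPerBlock) * ceil(B / BPerBlock), written with integer math as in the driver
    const int grid = ((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
    std::printf("GridSize = %d\n", grid);     // 8 * 4 = 32
    return 0;
}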
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...
@@ -149,7 +149,7 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
    using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
        GridSize,
        BlockSize,
        T,
...
@@ -181,28 +181,38 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
        GemmBBlockCopySrcDataPerRead_GemmN,
        GemmBBlockCopyDstDataPerWrite_GemmN,
        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
        GemmCThreadCopyDstDataPerWrite_GemmN1>;

    for(index_t i = 0; i < nrepeat; ++i)
    for(index_t i = 0; i < 5; ++i)
    {
        float time = launch_and_time_kernel(
            run_gridwise_operation<decltype(gridwise_conv),
                                   T* const __restrict__,
                                   const T* const __restrict__,
                                   const T* const __restrict__>,
            dim3(GridSize), dim3(BlockSize), 0, 0,
            gridwise_conv,
            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
        usleep(std::min(time * 1000, float(10000)));

        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
                                                 T* const __restrict__,
                                                 const T* const __restrict__,
                                                 const T* const __restrict__>,
                          dim3(GridSize), dim3(BlockSize), 0, 0,
                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;
        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    in_nchw_device_buf.FromDevice(in_nchw.mData.data());
...
...
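The change above switches from timing each launch separately to timing nrepeat back-to-back launches and reporting the average. A minimal sketch of that measurement pattern, using std::chrono in place of the repository's KernelTimer; the kernel launch is a placeholder, and a real GPU measurement would also need a device synchronization before stopping the clock.

#include <chrono>
#include <iostream>

void placeholder_kernel_launch() { /* stands in for launch_kernel(...) */ }

int main()
{
    const int nrepeat = 100;

    const auto t0 = std::chrono::high_resolution_clock::now();
    for(int j = 0; j < nrepeat; ++j)
    {
        placeholder_kernel_launch();
    }
    const auto t1 = std::chrono::high_resolution_clock::now();

    const float total_ms = std::chrono::duration<float, std::milli>(t1 - t0).count();
    const float ave_ms   = total_ms / nrepeat;

    std::cout << "Average time : " << ave_ms << " ms" << std::endl;
    return 0;
}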
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
@@ -55,25 +55,27 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
     constexpr index_t BPerBlock = 32;
     constexpr index_t EPerBlock = 32;
-    constexpr index_t KPerBlock = 8;
+    constexpr index_t KPerBlock = 16;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
-
-    using OutBlockCopySubLengths_K_B_N0 = Sequence<1, 1, 4>;
+    constexpr index_t GemmDataPerReadA = 4;
+    constexpr index_t GemmDataPerReadB = 4;
+
+    using OutBlockCopySubLengths_K_B_N0 = Sequence<2, 1, 4>;
     using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>;
     constexpr index_t OutBlockCopySrcDataPerRead_B   = 1;
     constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4;

-    using WeiBlockCopySubLengths_K_E_C0 = Sequence<1, 4, 1>;
+    using WeiBlockCopySubLengths_K_E_C0 = Sequence<2, 4, 1>;
     using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>;
     constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
@@ -82,8 +84,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
     constexpr index_t InThreadCopyDstDataPerWrite_B = 1;
 #endif

-    constexpr index_t C0 = GemmMPerThreadSubC;
-    constexpr index_t N0 = GemmNPerThreadSubC;
+    constexpr index_t C0 = GemmMPerThread;
+    constexpr index_t N0 = GemmNPerThread;

     constexpr index_t C1 = C / C0;
     constexpr index_t N1 = N / N0;
@@ -96,7 +98,7 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    constexpr auto gridwise_conv =
+    using gridwise_conv_bwd_data =
         GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer<
             GridSize,
             BlockSize,
@@ -112,13 +114,13 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
             EPerBlock,
             BPerBlock,
             KPerBlock,
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
+            GemmMPerThread,
+            GemmNPerThread,
+            GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
             GemmDataPerReadA,
             GemmDataPerReadB,
             OutBlockCopySubLengths_K_B_N0,
@@ -129,28 +131,38 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
             WeiBlockCopyClusterLengths_K_E_C0,
             WeiBlockCopySrcDataPerRead_E,
             WeiBlockCopyDstDataPerWrite_C0,
-            InThreadCopyDstDataPerWrite_B>{};
+            InThreadCopyDstDataPerWrite_B>;

-    for(index_t i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < 5; ++i)
     {
-        float time = launch_and_time_kernel(
-            run_gridwise_operation<decltype(gridwise_conv),
-                                   T* const __restrict__,
-                                   const T* const __restrict__,
-                                   const T* const __restrict__>,
-            dim3(GridSize), dim3(BlockSize), 0, 0, gridwise_conv,
-            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-
-        printf("Elapsed time : %f ms, %f TFlop/s \n",
-               time,
-               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                   (std::size_t(1000) * 1000 * 1000) / time);
-
-        usleep(std::min(time * 1000, float(10000)));
+        std::cout << "Start running " << nrepeat << " times..." << std::endl;
+
+        KernelTimer timer;
+        timer.Start();
+
+        for(index_t j = 0; j < nrepeat; ++j)
+        {
+            launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
+                                                 T* const __restrict__,
+                                                 const T* const __restrict__,
+                                                 const T* const __restrict__>,
+                          dim3(GridSize), dim3(BlockSize), 0, 0,
+                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        }
+
+        timer.End();
+
+        float ave_time = timer.GetElapsedTime() / nrepeat;
+
+        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
     }

     in_nchw_device_buf.FromDevice(in_nchw.mData.data());
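The perf value printed by these drivers is the convolution FLOP count divided by the elapsed time; dividing flops by 1e9 and then by milliseconds is the same as dividing by 1e12 and by seconds, which is why the result is labelled TFlop/s. A small sketch of that arithmetic with made-up tensor sizes follows; the 2*N*K*C*Ho*Wo*Y*X count is the usual direct-convolution formula, not copied from calculate_convolution_flops itself.

#include <cstdio>

int main()
{
    // Illustrative sizes only.
    const double N = 128, K = 256, C = 256, Ho = 14, Wo = 14, Y = 3, X = 3;

    // Each output element takes C*Y*X multiply-accumulates, i.e. 2*C*Y*X flops.
    const double flops = 2.0 * N * K * C * Ho * Wo * Y * X;

    const double time_ms = 1.5;                     // measured kernel time in milliseconds
    const double tflops  = flops / 1.0e9 / time_ms; // == flops / 1e12 / seconds

    std::printf("%.2f TFlop/s\n", tflops);
    return 0;
}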
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
@@ -185,7 +185,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
+    using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
         GridSize,
         BlockSize,
         T,
@@ -217,28 +217,38 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
         GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
         GemmBBlockCopySrcDataPerRead_GemmN,
         GemmBBlockCopyDstDataPerWrite_GemmN,
-        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+        GemmCThreadCopyDstDataPerWrite_GemmN1>;

-    for(index_t i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < 5; ++i)
     {
-        float time = launch_and_time_kernel(
-            run_gridwise_operation<decltype(gridwise_conv),
-                                   T* const __restrict__,
-                                   const T* const __restrict__,
-                                   const T* const __restrict__>,
-            dim3(GridSize), dim3(BlockSize), 0, 0, gridwise_conv,
-            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-
-        printf("Elapsed time : %f ms, %f TFlop/s \n",
-               time,
-               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                   (std::size_t(1000) * 1000 * 1000) / time);
-
-        usleep(std::min(time * 1000, float(10000)));
+        std::cout << "Start running " << nrepeat << " times..." << std::endl;
+
+        KernelTimer timer;
+        timer.Start();
+
+        for(index_t j = 0; j < nrepeat; ++j)
+        {
+            launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
+                                                 T* const __restrict__,
+                                                 const T* const __restrict__,
+                                                 const T* const __restrict__>,
+                          dim3(GridSize), dim3(BlockSize), 0, 0,
+                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        }
+
+        timer.End();
+
+        float ave_time = timer.GetElapsedTime() / nrepeat;
+
+        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
     }

     in_nchw_device_buf.FromDevice(in_nchw.mData.data());
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
@@ -124,7 +124,7 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
+    using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
         GridSize,
         BlockSize,
         T,
@@ -156,28 +156,38 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
         GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
         GemmBBlockCopySrcDataPerRead_GemmN,
         GemmBBlockCopyDstDataPerWrite_GemmN,
-        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+        GemmCThreadCopyDstDataPerWrite_GemmN1>;

-    for(index_t i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < 5; ++i)
     {
-        float time = launch_and_time_kernel(
-            run_gridwise_operation<decltype(gridwise_conv),
-                                   T* const __restrict__,
-                                   const T* const __restrict__,
-                                   const T* const __restrict__>,
-            dim3(GridSize), dim3(BlockSize), 0, 0, gridwise_conv,
-            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-
-        printf("Elapsed time : %f ms, %f TFlop/s \n",
-               time,
-               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                   (std::size_t(1000) * 1000 * 1000) / time);
-
-        usleep(std::min(time * 1000, float(10000)));
+        std::cout << "Start running " << nrepeat << " times..." << std::endl;
+
+        KernelTimer timer;
+        timer.Start();
+
+        for(index_t j = 0; j < nrepeat; ++j)
+        {
+            launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
+                                                 T* const __restrict__,
+                                                 const T* const __restrict__,
+                                                 const T* const __restrict__>,
+                          dim3(GridSize), dim3(BlockSize), 0, 0,
+                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        }
+
+        timer.End();
+
+        float ave_time = timer.GetElapsedTime() / nrepeat;
+
+        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
     }

     in_nchw_device_buf.FromDevice(in_nchw.mData.data());
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
@@ -2,18 +2,13 @@
 #include <unistd.h>
 #include "device.hpp"
 #include "tensor.hpp"
+#include "gridwise_operation_wrapper.hpp"
 #include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"

 namespace launcher {

 using namespace ck;

-template <typename GridwiseOp, index_t GemmId, typename... Xs>
-__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs)
-{
-    GridwiseOp::template Run<GemmId>(xs...);
-}
-
 template <typename T,
           typename InDesc,
           typename WeiDesc,
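The hunk above drops this file's private __global__ wrapper in favour of the shared gridwise_operation_wrapper.hpp header. The underlying idea is a single variadic kernel that forwards its arguments to a gridwise functor; a simplified sketch of such a wrapper (deliberately named differently, and not the exact definition in that header) looks like this.

// Simplified sketch only. GridwiseOp is any type with a static Run(...),
// Xs... are whatever arguments the gridwise operation expects.
template <typename GridwiseOp, typename... Xs>
__global__ void run_gridwise_op_sketch(Xs... xs)
{
    GridwiseOp::Run(xs...);
}

// Usage sketch:
//   launch_kernel(run_gridwise_op_sketch<MyGridwiseOp, float*, const float*>,
//                 dim3(grid_size), dim3(block_size), 0, 0, p_out, p_in);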
@@ -62,7 +57,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-#if 0
+#if 1
     // BlockSize = 256, each thread hold 64 data
     constexpr index_t BlockSize = 256;
@@ -92,36 +87,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
-    // BlockSize = 256, each thread hold 64 data
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t GemmThreadGemmDataPerReadM = 4;
-    constexpr index_t GemmThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<2, 4>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #endif

     constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
@@ -157,78 +122,82 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    for(index_t i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < 5; ++i)
     {
-        using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
-            GridSize, BlockSize, T, T,
-            decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
-            ConvStrides, ConvDilations, InLeftPads, InRightPads,
-            GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
-            GemmMPerThread, GemmNPerThread, GemmKPerThread,
-            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
-            GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN,
-            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
-            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-            GemmABlockCopySrcDataPerRead_GemmM,
-            GemmABlockCopyDstDataPerWrite_GemmM,
-            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
-            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
-            GemmBBlockCopySrcDataPerRead_GemmN,
-            GemmBBlockCopyDstDataPerWrite_GemmN,
-            GemmCThreadCopyDstDataPerWrite_GemmN1>;
+        std::cout << "Start running " << nrepeat << " times..." << std::endl;

         KernelTimer timer;
         timer.Start();

-        static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
-            constexpr index_t gemm_id = decltype(gemm_id_){};
-
-            constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
-
-            constexpr index_t gemm_k = gemm_sizes.At(2);
-
-            constexpr bool is_gemm_not_empty = gemm_k > 0;
-
-            // only compile and run if GEMM is no empty
-            static_if<is_gemm_not_empty>{}([&](auto fwd) {
-                launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConvBwdData,
-                                                                          fwd(gemm_id),
-                                                                          T* const __restrict__,
-                                                                          const T* const __restrict__,
-                                                                          const T* const __restrict__>,
-                              dim3(GridSize), dim3(BlockSize), 0, 0,
-                              static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                              static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-                              static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-            });
-        });
+        for(index_t i = 0; i < nrepeat; ++i)
+        {
+            using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
+                GridSize, BlockSize, T, T,
+                decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
+                ConvStrides, ConvDilations, InLeftPads, InRightPads,
+                GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
+                GemmMPerThread, GemmNPerThread, GemmKPerThread,
+                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
+                GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN,
+                GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+                GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                GemmABlockCopySrcDataPerRead_GemmM,
+                GemmABlockCopyDstDataPerWrite_GemmM,
+                GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+                GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                GemmBBlockCopySrcDataPerRead_GemmN,
+                GemmBBlockCopyDstDataPerWrite_GemmN,
+                GemmCThreadCopyDstDataPerWrite_GemmN1>;
+
+            static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
+                constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
+
+                constexpr index_t gemm_k = gemm_sizes.At(2);
+
+                constexpr bool is_gemm_not_empty = gemm_k > 0;
+
+                // only compile and run if GEMM is no empty
+                static_if<is_gemm_not_empty>{}([&](auto fwd) {
+                    launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
+                                                         T* const __restrict__,
+                                                         const T* const __restrict__,
+                                                         const T* const __restrict__,
+                                                         decltype(gemm_id)>,
+                                  dim3(GridSize), dim3(BlockSize), 0, 0,
+                                  static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                                  static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                                  static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
+                                  fwd(gemm_id));
+                });
+            });
+        }

         timer.End();

-        float time = timer.GetElapsedTime();
-
-        printf("Elapsed time : %f ms, %f TFlop/s \n",
-               time,
-               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                   (std::size_t(1000) * 1000 * 1000) / time);
-
-        usleep(std::min(time * 1000, float(10000)));
+        float ave_time = timer.GetElapsedTime() / nrepeat;
+
+        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
     }

     in_nchw_device_buf.FromDevice(in_nchw.mData.data());
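In the v4r1 backward-data path the convolution is split into several GEMMs, and the loop above instantiates and launches a kernel for each non-empty one at compile time via static_for and static_if. A stand-alone sketch of that dispatch idea in plain C++17, using std::integral_constant and if constexpr instead of the repository's helpers; the GEMM sizes here are invented for illustration.

#include <cstdio>
#include <type_traits>
#include <utility>

// Illustrative compile-time K sizes for 4 sub-GEMMs; two of them are empty.
constexpr int gemm_k_of[4] = {8, 0, 8, 0};

// Call f(std::integral_constant<int, 0>{}) ... f(<N-1>{}) in order.
template <int... Is, typename F>
void static_for_impl(std::integer_sequence<int, Is...>, F f)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
void static_for(F f)
{
    static_for_impl(std::make_integer_sequence<int, N>{}, f);
}

int main()
{
    static_for<4>([](auto gemm_id) {
        constexpr int id     = decltype(gemm_id)::value;
        constexpr int gemm_k = gemm_k_of[id];

        // only instantiate and "launch" if this GEMM is non-empty
        if constexpr(gemm_k > 0)
        {
            std::printf("launching gemm %d (k = %d)\n", id, gemm_k);
        }
    });
    return 0;
}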
driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp
deleted 100644 → 0
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp"

using namespace ck;

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
                                                 const Tensor<T>& in,
                                                 WeiDesc,
                                                 const Tensor<T>& wei,
                                                 OutDesc,
                                                 Tensor<T>& out,
                                                 index_t nrepeat)
{
    std::size_t data_sz = sizeof(T);
    DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
    DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
    DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());

    int num_thread = std::thread::hardware_concurrency();

    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    out_device_buf.ToDevice(out.mData.data());

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 1
    // 3x3, 34x34, 128 thread
    constexpr index_t NPerBlock  = 2;
    constexpr index_t KPerBlock  = 32;
    constexpr index_t CPerBlock  = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread  = 2;
    constexpr index_t KPerThread  = 4;
    constexpr index_t CPerThread  = 2;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopyDataPerRead  = 1;
    constexpr index_t WeiBlockCopyDataPerRead = 1;
    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize = (out_desc.GetLength(I0) / NPerBlock) *
                                 (out_desc.GetLength(I1) / KPerBlock) *
                                 (out_desc.GetLength(I2) / HoPerBlock) *
                                 (out_desc.GetLength(I3) / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        using gridwise_conv = GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw<
            GridSize, BlockSize, T, InDesc, WeiDesc, OutDesc,
            NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
            NPerThread, KPerThread, CPerThread, HoPerThread, WoPerThread,
            InBlockCopyDataPerRead, WeiBlockCopyDataPerRead>;

        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
                                            dim3(GridSize), dim3(BlockSize), 0,
                                            static_cast<T*>(in_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(out_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms \n", time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_device_buf.FromDevice(out.mData.data());
}
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
deleted 100644 → 0
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp"

using namespace ck;

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
                                                        const Tensor<T>& in_nchw,
                                                        WeiDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
                                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight
    auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());

    // reorder input
    auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // output
    auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    std::size_t data_sz = sizeof(T);
    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 0
    // for 3x3, 34x34, v1r1, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 1
    // for 3x3, 34x34, v1r3, Pascal
    // for 3x3, 28x28, v1r3, Pascal
    // for 3x3, 14x14, v1r3, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopySubLengths_CHWN     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    using WeiBlockCopySubLengths_CK     = Sequence<2, 4>;
    using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
    // for 3x3, 34x34, v1r1, Vega 20
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
    constexpr index_t InBlockCopyDataPerAccess_N = 2;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 2;
    constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 1
    // for 3x3, 34x34, v1r3, Vega 20
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopySubLengths_CHWN     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    using WeiBlockCopySubLengths_CK     = Sequence<1, 4>;
    using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 0
    // for 3x3, 56x56, v1r1, Pascal
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 56x56, v1r2, Pascal
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 1;
    constexpr index_t GemmDataPerReadB = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 4;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 28x28, v1r1, Pacal
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 28x28, v1r2, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
    // for 1x1, 28x28, v1r1, Pascal
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 1x1, 14x14, v1r1, Pascal
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 8;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 2;
    constexpr index_t BlockSize = 128;
#endif

    constexpr index_t GridSize =
        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv =
#if 0
        GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
        GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
        GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
        GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif
        <GridSize, BlockSize, T,
         decltype(in_chwn_desc), decltype(wei_cyxk_desc), decltype(out_khwn_desc),
         NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
         NPerThread, KPerThread, HoPerThread, WoPerThread,
         GemmMPerThreadSubC, GemmNPerThreadSubC,
         GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
         GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
         InBlockCopySubLengths_CHWN, InBlockCopyClusterLengths_CHWN, InBlockCopyDataPerAccess_N,
         WeiBlockCopySubLengths_CK, WeiBlockCopyClusterLengths_CK, WeiBlockCopyDataPerAccess_K,
         OutThreadCopyDataPerAccess_N>{};

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                            dim3(GridSize), dim3(BlockSize), 0,
                                            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // reorder output
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
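The deleted v1 CHWN drivers do their layout conversions on the host: NCHW to CHWN for the input, KCYX to CYXK for the weights, and KHWN back to NKHW for the output, before and after the kernel runs. A minimal single-threaded sketch of the input reorder on flat buffers follows; it shows only the index arithmetic, whereas the repository does this through Tensor and make_ParallelTensorFunctor.

#include <cstddef>
#include <vector>

// Reorder a packed NCHW buffer into a packed CHWN buffer.
// `in` has layout [N][C][H][W]; `out` has layout [C][H][W][N].
void reorder_nchw_to_chwn(const std::vector<float>& in, std::vector<float>& out,
                          std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    const std::size_t src = ((n * C + c) * H + h) * W + w;
                    const std::size_t dst = ((c * H + h) * W + w) * N + n;
                    out[dst] = in[src];
                }
}

int main()
{
    const std::size_t N = 2, C = 3, H = 4, W = 4; // illustrative sizes
    std::vector<float> in(N * C * H * W, 1.0f), out(N * C * H * W, 0.0f);
    reorder_nchw_to_chwn(in, out, N, C, H, W);
    return 0;
}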
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
deleted 100644 → 0
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"

using namespace ck;

template <typename T, class InDesc, class WeiDesc, class OutDesc, class LeftPads, class RightPads>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
                                                               const Tensor<T>& in_nchw,
                                                               WeiDesc,
                                                               const Tensor<T>& wei_kcyx,
                                                               OutDesc,
                                                               Tensor<T>& out_nkhw,
                                                               LeftPads,
                                                               RightPads,
                                                               index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight
    auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());

    // reorder input
    auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // output
    auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    std::size_t data_sz = sizeof(T);
    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 1
    // v1r3, 3x3, 32x32, 1x1 pad
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopySubLengths_CHWN     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>;
    constexpr index_t InBlockCopyDataPerAccess_N = 4;
    using WeiBlockCopySubLengths_CK     = Sequence<1, 4>;
    using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
    constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif

#if 1 // debug
    constexpr index_t GridSize =
        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
    constexpr index_t GridSize = 1;
#endif

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<
        GridSize, BlockSize, T,
        decltype(in_chwn_desc), decltype(wei_cyxk_desc), decltype(out_khwn_desc),
        LeftPads, RightPads,
        NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
        NPerThread, KPerThread, HoPerThread, WoPerThread,
        GemmMPerThreadSubC, GemmNPerThreadSubC,
        GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
        GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
        InBlockCopySubLengths_CHWN, InBlockCopyClusterLengths_CHWN, InBlockCopyDataPerAccess_N,
        WeiBlockCopySubLengths_CK, WeiBlockCopyClusterLengths_CK, WeiBlockCopyDataPerAccess_K,
        OutThreadCopyDataPerAccess_N>{};

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                            dim3(GridSize), dim3(BlockSize), 0,
                                            static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // reorder output
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
deleted 100644 → 0
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp"

using namespace ck;

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
                                                        const Tensor<T>& in_nchw,
                                                        WeiDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
                                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight
    auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 0
    // for 3x3, 34x34, v1r3, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 16;
    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 1;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 1;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread = 1;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 8;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW     = Sequence<1, 2, 2, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 1;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
#elif 1
    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 16;
    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW     = Sequence<2, 1, 2, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 2;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 4;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 8;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 4;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 8;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 4;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
    constexpr index_t BlockSize = 256;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 4;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 1
    // for 3x3, 28x28, v1r3, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 4;
    using WeiBlockCopyClusterLengths = void;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#endif

    constexpr index_t GridSize = ((N + NPerBlock - 1) / NPerBlock) *
                                 ((K + KPerBlock - 1) / KPerBlock) *
                                 ((Ho + HoPerBlock - 1) / HoPerBlock) *
                                 ((Wo + WoPerBlock - 1) / WoPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        constexpr auto gridwise_conv =
#if 0
            GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#else
            GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
#endif
            <GridSize, BlockSize, T,
             decltype(in_nchw_desc), decltype(wei_cyxk_desc), decltype(out_nkhw_desc),
             NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
             NPerThread, KPerThread, HoPerThread, WoPerThread,
             GemmMPerThreadSubC, GemmNPerThreadSubC,
             GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
             GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
             InBlockReorderSrcSubLengths_NCHW,
             InBlockReorderSrcClusterLengths_NCHW,
             InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
             InBlockReorderDataPerRead_W,
             InBlockReorderDataPerWrite_N,
             WeiBlockCopyClusterLengths,
             WeiBlockCopyDataPerRead_K,
             OutThreadCopyDataPerWrite_W>{};

        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                            dim3(GridSize), dim3(BlockSize), 0,
                                            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
deleted 100644 → 0
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp"
using namespace ck;

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
                                                         const Tensor<T>& in_nchw,
                                                         WeiDesc,
                                                         const Tensor<T>& wei_kcyx,
                                                         OutDesc,
                                                         Tensor<T>& out_nkhw,
                                                         index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto in_nchw_desc  = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};

constexpr index_t N  = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);

constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);

constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
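BGhostRead accounts for the fact that, with H and W flattened into one contiguous axis, a Y x X filter window anchored at the last valid position still reads up to (Y-1)*Wi + (X-1) elements past it; the input device buffer is over-allocated by that amount further down. A small illustrative sketch (the sizes are assumptions, not from this commit):

// Why the ghost region is (Y-1)*Wi + (X-1) elements, with the H*W axes flattened (illustrative).
#include <cassert>

int main()
{
    constexpr int Wi = 34, Y = 3, X = 3; // assumed input width and filter size

    // The furthest filter tap relative to the window origin is (y, x) = (Y-1, X-1),
    // which sits (Y-1)*Wi + (X-1) elements away in the flattened H*W axis.
    const int furthest_read = (Y - 1) * Wi + (X - 1);

    assert(furthest_read == 70); // matches the BGhostRead allowance above
    return 0;
}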
// convert in_nchw to in_chwn
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

make_ParallelTensorFunctor(
    [&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); },
    N,
    C,
    Hi,
    Wi)(std::thread::hardware_concurrency());
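The parallel functor above transposes the host tensor from NCHW to CHWN before it is copied to the device. A self-contained sketch of the same index mapping on plain std::vector storage, assuming row-major packing and small illustrative sizes:

// Standalone sketch of the NCHW -> CHWN host-side transpose done above (illustrative sizes).
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    constexpr int N = 2, C = 3, H = 4, W = 5;

    std::vector<float> in_nchw(N * C * H * W);
    for(std::size_t i = 0; i < in_nchw.size(); ++i)
        in_nchw[i] = static_cast<float>(i);

    std::vector<float> in_chwn(in_nchw.size());
    for(int n = 0; n < N; ++n)
        for(int c = 0; c < C; ++c)
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    in_chwn[((c * H + h) * W + w) * N + n] = in_nchw[((n * C + c) * H + h) * W + w];

    // Spot-check one element: (n=1, c=2, h=3, w=4) must land at (c=2, h=3, w=4, n=1).
    assert(in_chwn[((2 * H + 3) * W + 4) * N + 1] == in_nchw[((1 * C + 2) * H + 3) * W + 4]);
    return 0;
}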
// convert wei_kcyx to wei_cyxk
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

make_ParallelTensorFunctor(
    [&](auto k, auto c, auto y, auto x) { wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); },
    K,
    C,
    Y,
    X)(std::thread::hardware_concurrency());
// convert out_nkhw to out_khwn
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;

constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster    = 8;

constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;

constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

constexpr index_t InBlockCopyDataPerRead  = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;

constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;

constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster    = 8;

constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;

constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

constexpr index_t InBlockCopyDataPerRead  = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;

constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 threads
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;

constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster    = 8;

constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;

constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

constexpr index_t InBlockCopyDataPerRead  = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;

constexpr index_t BlockSize = 256;
#elif 0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;

constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;

constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

constexpr index_t InBlockCopyDataPerRead     = 4;
constexpr index_t WeiBlockCopyDataPerRead    = 4;
constexpr index_t OutThreadCopyDataPerWrite  = 4;

constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;

constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;

constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

constexpr index_t InBlockCopyDataPerRead     = 4;
constexpr index_t WeiBlockCopyDataPerRead    = 4;
constexpr index_t OutThreadCopyDataPerWrite  = 4;

constexpr index_t BlockSize = 256;
#endif
constexpr index_t GridSize =
    ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);
// mem
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead +
                                        BPerBlock)); // reserve extra space for BGhostRead
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(index_t i = 0; i < nrepeat; ++i)
{
    constexpr auto gridwise_conv =
#if 0
        GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
#else
        GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#endif
        <GridSize,
         BlockSize,
         T,
         decltype(in_chwn_desc),
         decltype(wei_cyxk_desc),
         decltype(out_khwn_desc),
         BPerBlock,
         KPerBlock,
         CPerBlock,
         BPerThread,
         KPerThread,
         GemmMPerThreadSubC,
         GemmNPerThreadSubC,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
         GemmKPerThreadLoop,
         GemmDataPerReadA,
         GemmDataPerReadB,
         InBlockCopyThreadPerDim0,
         InBlockCopyThreadPerDim1,
         WeiBlockCopyThreadPerDim0,
         WeiBlockCopyThreadPerDim1,
         InBlockCopyDataPerRead,
         WeiBlockCopyDataPerRead,
         OutThreadCopyDataPerWrite>{};

    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                        dim3(GridSize),
                                        dim3(BlockSize),
                                        0,
                                        static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                                        static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                        static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

    printf("Elapsed time : %f ms, %f TFlop/s\n",
           time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / time);

    usleep(std::min(time * 1000, float(10000)));
}

out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// convert out_khwn to out_nkhw
make_ParallelTensorFunctor(
    [&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); },
    N,
    K,
    Ho,
    Wo)(std::thread::hardware_concurrency());
}
driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
deleted
100644 → 0
View file @
80901f59
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
                                                         const Tensor<T>& in_nchw,
                                                         WeiDesc,
                                                         const Tensor<T>& wei_kcyx,
                                                         OutDesc,
                                                         Tensor<T>& out_nkhw,
                                                         index_t nrepeat)
{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto in_nchw_desc  = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};

constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);

constexpr index_t N  = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
    wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};

make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;

constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
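B folds the batch and output spatial dimensions into the GEMM's N dimension, with each thread covering N1 * N2 = 8 values along that axis, so N * Ho * Wo must be divisible by N1 * N2 for this driver's tiling to hold. A worked example with an assumed problem size:

// Worked example of the B computation above (illustrative sizes, not from this commit).
#include <cstdio>

int main()
{
    constexpr int N = 64, Ho = 14, Wo = 14; // assumed problem size
    constexpr int N1 = 2, N2 = 4;

    constexpr int B = (N * Ho * Wo) / (N1 * N2); // 64 * 14 * 14 / 8 = 1568
    static_assert((N * Ho * Wo) % (N1 * N2) == 0, "tiling assumption: B must divide evenly");

    printf("B = %d\n", B);
    return 0;
}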
#if 1
constexpr index_t BlockSize = 256;

constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA   = 4;
constexpr index_t GemmDataPerReadB   = 4;

using InBlockCopySubLengths_C_N1_B_N2     = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_C_N1_B_N2 = Sequence<8, 2, 16, 1>;

constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

using WeiBlockCopySubLengths_C_K     = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_C_K = Sequence<8, 32>;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
#endif
constexpr index_t GridSize =
    ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
    constexpr auto gridwise_conv =
#if 0
        GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
#else
        GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
#endif
        <GridSize,
         BlockSize,
         T,
         decltype(in_nchw_desc),
         decltype(wei_cyxk_desc),
         decltype(out_nkhw_desc),
         BPerBlock,
         KPerBlock,
         CPerBlock,
         N1,
         N2,
         GemmMPerThreadSubC,
         GemmNPerThreadSubC,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
         GemmKPerThreadLoop,
         GemmDataPerReadA,
         GemmDataPerReadB,
         InBlockCopySubLengths_C_N1_B_N2,
         InBlockCopyClusterLengths_C_N1_B_N2,
         InBlockCopySrcDataPerRead_B,
         InBlockCopyDstDataPerWrite_N2,
         WeiBlockCopySubLengths_C_K,
         WeiBlockCopyClusterLengths_C_K,
         WeiBlockCopyDataPerAccess_K>{};

    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                        dim3(GridSize),
                                        dim3(BlockSize),
                                        0,
                                        static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                        static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                        static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

    printf("Elapsed time : %f ms, %f TFlop/s\n",
           time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / time);

    usleep(std::min(time * 1000, float(10000)));
}

out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
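For context, a hypothetical caller of one of these deleted drivers might look like the sketch below; it assumes the descriptor and tensor helpers behave exactly as they are used in the files above, and the sizes are illustrative rather than taken from this commit.

// Hypothetical host-side caller (a sketch under the assumptions stated above).
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"

using namespace ck;

void run_example()
{
    // Descriptors and tensors created with the same helpers the deleted drivers use.
    constexpr auto in_desc  = make_ConstantTensorDescriptor_packed(Sequence<64, 128, 14, 14>{}); // NCHW
    constexpr auto wei_desc = make_ConstantTensorDescriptor_packed(Sequence<128, 128, 1, 1>{});  // KCYX
    constexpr auto out_desc = make_ConstantTensorDescriptor_packed(Sequence<64, 128, 14, 14>{}); // NKHW

    Tensor<float> in(make_TensorDescriptor(in_desc));
    Tensor<float> wei(make_TensorDescriptor(wei_desc));
    Tensor<float> out(make_TensorDescriptor(out_desc));

    // ... fill `in` and `wei` with test data here ...

    device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
        in_desc, in, wei_desc, wei, out_desc, out, /*nrepeat=*/10);
}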