gaoqiong / composable_kernel · Commits · a037693f

Commit a037693f, authored Dec 01, 2021 by ltqin
Merge branch 'develop' into conv_splitk_f32
Parents: 0694d6ed, 4041850f

Changes: 38 files in the commit; showing 20 changed files with 2745 additions and 434 deletions.
CMakeLists.txt  +1 -0
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp  +92 -100
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v3.hpp  +1920 -0
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp  +104 -96
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp  +35 -0
composable_kernel/include/utility/amd_buffer_addressing.hpp  +24 -16
composable_kernel/include/utility/amd_inline_asm.hpp  +14 -14
composable_kernel/include/utility/amd_xdlops.hpp  +4 -4
composable_kernel/include/utility/config.hpp  +23 -2
composable_kernel/include/utility/data_type.hpp  +3 -3
composable_kernel/include/utility/dynamic_buffer.hpp  +40 -0
composable_kernel/include/utility/inner_product.hpp  +2 -2
composable_kernel/include/utility/magic_division.hpp  +1 -1
composable_kernel/include/utility/statically_indexed_array.hpp  +44 -0
composable_kernel/include/utility/type.hpp  +9 -1
example/1_gemm_xdl/gemm_xdl.cpp  +4 -5
host/driver_offline/CMakeLists.txt  +9 -0
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  +220 -0
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  +196 -0
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp  +0 -190
CMakeLists.txt

@@ -200,3 +200,4 @@ enable_cppcheck(
add_subdirectory(host)
add_subdirectory(example)
add_subdirectory(profiler)
add_subdirectory(test)
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp

@@ -10,99 +10,99 @@ template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename BlockMatrixA,
-          typename BlockMatrixB,
-          typename ThreadMatrixC,
-          index_t KPerThread,
-          index_t HPerThread,
-          index_t WPerThread,
+          typename ABlockDesc_E1_K1_E2,
+          typename BBlockDesc_E1_N_Ho_Wo_E2,
+          typename CThreadDesc_K_N_Ho_Wo,
           index_t EPerThreadLoop,
-          index_t ThreadGemmADataPerRead_K,
-          index_t ThreadGemmBDataPerRead_W>
+          index_t KPerThreadLoop>
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
 {
-    struct MatrixIndex
-    {
-        index_t k;
-        index_t h;
-        index_t w;
-    };
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+
+    using AIndex = MultiIndex<3>;
+    using BIndex = MultiIndex<3>;
+    using CIndex = MultiIndex<4>;
+
+    static constexpr auto E1        = ABlockDesc_E1_K1_E2{}.GetLength(I0);
+    static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
+    static constexpr auto E2        = ABlockDesc_E1_K1_E2{}.GetLength(I2);
+
+    static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
+    static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);

     // HACK: fix this @Jing Zhang
-    static constexpr index_t KPerThreadSubC = 4;
+    static constexpr auto KPerThread  = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
+    static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
+    static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);

     static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
+        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));

-    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
-        Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<EPerThreadLoop>{},
+                   Number<1>{},
+                   Number<HoPerThread>{},
+                   Number<WoPerThread>{},
+                   Number<E2>{}));

     static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
-        Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
-                                                         BlockMatrixA,
-                                                         decltype(a_thread_mtx_),
-                                                         Sequence<EPerThreadLoop, KPerThreadSubC>,
-                                                         Sequence<0, 1>,
-                                                         1,
-                                                         ThreadGemmADataPerRead_K,
-                                                         1>;
+        Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

     __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
-        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
-          a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
+        : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
+          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
     {
-        static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
-                          BlockMatrixB::IsKnownAtCompileTime() &&
-                          ThreadMatrixC::IsKnownAtCompileTime(),
+        static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
+                          BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
+                          CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
-        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
-                      "wrong! K dimension not consistent\n");
+        static_assert(ABlockDesc_E1_K1_E2{}.GetLength(I0) ==
+                              BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
+                          ABlockDesc_E1_K1_E2{}.GetLength(I2) ==
+                              BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
+                      "wrong! E dimension not consistent\n");

-        constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t H = BlockMatrixB{}.GetLength(I2);
-        constexpr index_t W = BlockMatrixB{}.GetLength(I3);
+        static_assert(E1 % EPerThreadLoop == 0, "");
+        static_assert(KPerThread % KPerThreadLoop == 0, "");

-        static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
+        static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
+                          WoPerBlock % WoPerThread == 0,
                       "wrong! Cannot evenly divide work among\n");

-        constexpr auto KThreadCluster = K / KPerThread;
-        constexpr auto HThreadCluster = H / HPerThread;
-        constexpr auto WThreadCluster = W / WPerThread;
+        constexpr auto KThreadCluster = KPerBlock / KPerThread;
+        constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
+        constexpr auto WThreadCluster = WoPerBlock / WoPerThread;

         static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
                       "wrong! wrong blocksize\n");
     }

-    __device__ static constexpr auto GetThreadMatrixCLengths()
+    __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
     {
-        return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
+        return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
     }

-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
+    __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
     {
-        constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
-        constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
-
-        constexpr auto num_w_threads  = W / WPerThread;
-        constexpr auto num_h_threads  = H / HPerThread;
-        constexpr auto num_hw_threads = num_w_threads * num_h_threads;
-
-        index_t k_thread_id  = thread_id / num_hw_threads;
-        index_t hw_thread_id = thread_id % num_hw_threads;
-
-        index_t h_thread_id = hw_thread_id / num_w_threads;
-        index_t w_thread_id = hw_thread_id % num_w_threads;
-
-        return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
+        constexpr auto K0 = KPerBlock / KPerThread;
+        constexpr auto N0 = I1;
+        constexpr auto H0 = HoPerBlock / HoPerThread;
+        constexpr auto W0 = WoPerBlock / WoPerThread;
+
+        constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
+                make_tuple(Sequence<0, 1, 2, 3>{}),
+                make_tuple(Sequence<0>{}));
+
+        const auto c_k_n_h_w_thread_cluster_idx =
+            c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(thread_id));
+
+        return c_k_n_h_w_thread_cluster_idx;
     }

     template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>

@@ -116,19 +116,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                           is_same<remove_cvref_t<typename CThreadBuffer::type>,
                                   remove_cvref_t<FloatC>>::value &&
                       "wrong! inconsistent type");

-        constexpr auto I0 = Number<0>{};
-
-        constexpr auto a_block_mtx = BlockMatrixA{};
-
-        constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
-
-        // HACK: fix this @Jing Zhang
-        constexpr auto HoPerThreadSubC = 2;
-        constexpr auto WoPerThreadSubC = 2;
-
-        static_assert(KPerThread % KPerThreadSubC == 0, "");
-        static_assert(HPerThread % HoPerThreadSubC == 0, "");
-        static_assert(WPerThread % WoPerThreadSubC == 0, "");
+        constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};

         // thread A buffer for GEMM
         StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>

@@ -139,42 +127,46 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
             FloatC,
             decltype(a_thread_mtx_),
             decltype(b_thread_mtx_),
-            decltype(c_thread_mtx_),
-            HoPerThreadSubC,
-            WoPerThreadSubC>{};
+            decltype(c_thread_mtx_)>{};

-        static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
-            static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
+        static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
+            static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
                 a_thread_copy_.Run(a_block_mtx,
-                                   make_tuple(e_begin, k_begin),
+                                   make_tuple(e_begin, k_begin, I0),
                                    a_block_buf,
                                    a_thread_mtx_,
-                                   make_tuple(I0, I0),
+                                   make_tuple(I0, I0, I0),
                                    a_thread_buf);

-                static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
-                    static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
-                        threadwise_gemm.Run(a_thread_buf,
-                                            make_tuple(I0, I0),
-                                            b_thread_buf,
-                                            make_tuple(e_begin, I0, h_begin, w_begin),
-                                            c_thread_buf,
-                                            make_tuple(k_begin, I0, h_begin, w_begin));
-                    });
-                });
+                threadwise_gemm.Run(a_thread_buf,
+                                    make_tuple(I0, I0, I0),
+                                    b_thread_buf,
+                                    make_tuple(e_begin, I0, I0, I0, I0),
+                                    c_thread_buf,
+                                    make_tuple(k_begin, I0, I0, I0));
             });
         });
     }

     template <typename ABlockSliceMoveStepIdx>
-    __device__ void MoveASliceWindow(const BlockMatrixA&,
-                                     const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
+    __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
     {
-        a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
+        a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
     }

     private:
-    MatrixIndex c_thread_begin_mtx_idx_;
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                          FloatA,
+                                                         ABlockDesc_E1_K1_E2,
+                                                         decltype(a_thread_mtx_),
+                                                         Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
+                                                         Sequence<0, 1, 2>,
+                                                         2,
+                                                         E2,
+                                                         E2>;
+
+    CIndex c_thread_origin_data_idx_;

     AThreadCopy a_thread_copy_;
 };
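Note on the refactor above: GetBeginOfThreadMatrixC's hand-written div/mod arithmetic is replaced by a merge-transform adaptor in GetBeginOfCThreadDesc_K_N_Ho_Wo. The sketch below is an illustration only (not part of the commit, with hypothetical cluster sizes); it shows, in plain C++, the index decomposition that CalculateBottomIndex of the single-stage merge adaptor performs, assuming the usual row-major merge order with the last dimension varying fastest.

#include <array>
#include <cstdio>

// Decompose a flat thread id into (k0, n0, h0, w0) cluster coordinates, last dimension fastest.
// This mirrors make_merge_transform(make_tuple(K0, N0, H0, W0)) followed by CalculateBottomIndex.
std::array<int, 4> decompose_thread_id(int thread_id, int K0, int N0, int H0, int W0)
{
    (void)K0; // the slowest-varying length is not needed for the decomposition itself
    const int w0 = thread_id % W0;
    thread_id /= W0;
    const int h0 = thread_id % H0;
    thread_id /= H0;
    const int n0 = thread_id % N0;
    thread_id /= N0;
    const int k0 = thread_id;
    return {k0, n0, h0, w0};
}

int main()
{
    // Hypothetical cluster: K0 = 4, N0 = 1, H0 = 8, W0 = 8 -> 256 threads per block.
    for(int tid : {0, 7, 8, 63, 64, 255})
    {
        const auto idx = decompose_thread_id(tid, 4, 1, 8, 8);
        std::printf("tid %3d -> (k0=%d, n0=%d, h0=%d, w0=%d)\n", tid, idx[0], idx[1], idx[2], idx[3]);
    }
    return 0;
}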
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v3.hpp (new file, mode 100644)

#ifndef CK_GRIDWISE_GEMM_V3_HPP
#define CK_GRIDWISE_GEMM_V3_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "blockwise_gemm_dlops_v3.hpp"

namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatC* __restrict__ p_bias_grid,
            FloatC* __restrict__ p_c_grid,
            const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
            const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
            const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
            const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::ConvBiasActiv(p_a_grid,
                                p_b_grid,
                                p_bias_grid,
                                p_c_grid,
                                p_shared_block,
                                a_e0_e1_k0_k1_e2_grid_desc,
                                b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                integral_constant<bool, HasMainE0BlockLoop>{},
                                integral_constant<ActivTypeEnum_t, ActivType>{});
}

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3_resize_add(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatC* __restrict__ p_bias_grid,
            FloatC* __restrict__ p_d_grid,
            const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
            const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
            const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
            const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
            const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid,
                                         p_b_grid,
                                         p_bias_grid,
                                         p_d_grid,
                                         p_shared_block,
                                         a_e0_e1_k0_k1_e2_grid_desc,
                                         b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                         c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                         d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                         c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                         integral_constant<bool, HasMainE0BlockLoop>{},
                                         integral_constant<ActivTypeEnum_t, ActivType>{});
}

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3_maxpool(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatC* __restrict__ p_bias_grid,
            FloatC* __restrict__ p_c_grid,
            FloatC* __restrict__ p_d_grid,
            const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
            const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
            const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
            const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
            const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::ConvBiasActivMaxpool(p_a_grid,
                                       p_b_grid,
                                       p_bias_grid,
                                       p_c_grid,
                                       p_d_grid,
                                       p_shared_block,
                                       a_e0_e1_k0_k1_e2_grid_desc,
                                       b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                       c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                       d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                       c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                       integral_constant<bool, HasMainE0BlockLoop>{},
                                       integral_constant<ActivTypeEnum_t, ActivType>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
                             const FloatC* __restrict__ p_bias_grid,
                             FloatC* __restrict__ p_c_grid,
                             const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
                             const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                             const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                             const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
    // first cast void CONSTANT void* to void*
    // second cast void* to Desc*
    // the copy constructor of tensor descriptor doesn't take address_space(4)
    const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>(
        cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc));
    const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
        *reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>(
            cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc));
    const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
        *reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
            cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
    const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
            cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::ConvBiasActiv(p_a_grid,
                                p_b_grid,
                                p_bias_grid,
                                p_c_grid,
                                p_shared_block,
                                a_e0_e1_k0_k1_e2_grid_desc,
                                b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                integral_constant<bool, HasMainE0BlockLoop>{},
                                integral_constant<ActivTypeEnum_t, ActivType>{});
}

// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3_resize_add(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatC* __restrict__ p_bias_grid,
            FloatC* __restrict__ p_d_grid,
            const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
            const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
            const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
            const void CONSTANT* p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
            const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
    // first cast void CONSTANT void* to void*
    // second cast void* to Desc*
    // the copy constructor of tensor descriptor doesn't take address_space(4)
    const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>(
        cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc));
    const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
        *reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>(
            cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc));
    const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
        *reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
            cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
    const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
        *reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx*>(
            cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc));
    const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
            cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid,
                                         p_b_grid,
                                         p_bias_grid,
                                         p_d_grid,
                                         p_shared_block,
                                         a_e0_e1_k0_k1_e2_grid_desc,
                                         b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                         c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                         d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                         c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                         integral_constant<bool, HasMainE0BlockLoop>{},
                                         integral_constant<ActivTypeEnum_t, ActivType>{});
}

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3_maxpool(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatC* __restrict__ p_bias_grid,
            FloatC* __restrict__ p_c_grid,
            FloatC* __restrict__ p_d_grid,
            const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
            const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
            const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
            const void CONSTANT* p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
            const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
    // first cast void CONSTANT void* to void*
    // second cast void* to Desc*
    // the copy constructor of tensor descriptor doesn't take address_space(4)
    const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>(
        cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc));
    const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
        *reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>(
            cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc));
    const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
        *reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
            cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
    const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
        *reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx*>(
            cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc));
    const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
            cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::ConvBiasActivMaxpool(p_a_grid,
                                       p_b_grid,
                                       p_bias_grid,
                                       p_c_grid,
                                       p_d_grid,
                                       p_shared_block,
                                       a_e0_e1_k0_k1_e2_grid_desc,
                                       b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                       c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                       d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                       c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                       integral_constant<bool, HasMainE0BlockLoop>{},
                                       integral_constant<ActivTypeEnum_t, ActivType>{});
}
#elif CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3_resize_add(const FloatAB* __restrict__ p_a_grid,
                                        const FloatAB* __restrict__ p_b_grid,
                                        const FloatC* __restrict__ p_bias_grid,
                                        FloatC* __restrict__ p_d_grid)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
    constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
        BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
    constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2{};
    constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx{};
    constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
        CBlockIdToBlockClusterAdaptor_K_N_H_W{};

    GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid,
                                         p_b_grid,
                                         p_bias_grid,
                                         p_d_grid,
                                         p_shared_block,
                                         a_e0_e1_k0_k1_e2_grid_desc,
                                         b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                         c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                         d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                         c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                         integral_constant<bool, HasMainE0BlockLoop>{},
                                         integral_constant<ActivTypeEnum_t, ActivType>{});
}

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3_maxpool(const FloatAB* __restrict__ p_a_grid,
                                     const FloatAB* __restrict__ p_b_grid,
                                     const FloatC* __restrict__ p_bias_grid,
                                     FloatC* __restrict__ p_c_grid,
                                     FloatC* __restrict__ p_d_grid)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
    constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
        BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
    constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2{};
    constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx{};
    constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
        CBlockIdToBlockClusterAdaptor_K_N_H_W{};

    GridwiseGemm::ConvBiasActivMaxpool(p_a_grid,
                                       p_b_grid,
                                       p_bias_grid,
                                       p_c_grid,
                                       p_d_grid,
                                       p_shared_block,
                                       a_e0_e1_k0_k1_e2_grid_desc,
                                       b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                       c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                       d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                       c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                       integral_constant<bool, HasMainE0BlockLoop>{},
                                       integral_constant<ActivTypeEnum_t, ActivType>{});
}

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum_t ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v3(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
                             const FloatC* __restrict__ p_bias_grid,
                             FloatC* __restrict__ p_c_grid)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
    constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
        BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
    constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2{};
    constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
        CBlockIdToBlockClusterAdaptor_K_N_H_W{};

    GridwiseGemm::ConvBiasActiv(p_a_grid,
                                p_b_grid,
                                p_bias_grid,
                                p_c_grid,
                                p_shared_block,
                                a_e0_e1_k0_k1_e2_grid_desc,
                                b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                c_blockid_to_k_n_h_w_block_cluster_adaptor,
                                integral_constant<bool, HasMainE0BlockLoop>{},
                                integral_constant<ActivTypeEnum_t, ActivType>{});
}
#endif
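The three preprocessor branches above differ only in how the grid descriptors reach the kernel: copied by value, reconstructed from a CONSTANT void pointer, or default-constructed from a compile-time type. The sketch below is an illustration only (not part of the commit; the DescLike struct and names are hypothetical): it shows the by-void-pointer round trip in ordinary host C++, minus the address-space conversion that cast_pointer_to_generic_address_space performs on the GPU.

#include <cstdio>

// Hypothetical stand-in for a tensor descriptor whose layout is fixed at compile time.
struct DescLike
{
    int lengths[4];
};

// Kernel-side view: receive an opaque pointer and copy the descriptor out of it by value,
// as the kernels above do with reinterpret_cast after the address-space cast.
void consume(const void* p_desc_opaque)
{
    const auto desc = *reinterpret_cast<const DescLike*>(p_desc_opaque);
    std::printf("lengths: %d %d %d %d\n",
                desc.lengths[0], desc.lengths[1], desc.lengths[2], desc.lengths[3]);
}

int main()
{
    const DescLike desc{{1, 16, 128, 8}};
    consume(&desc); // host-side analogue of passing the descriptor through a void CONSTANT*
    return 0;
}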
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_E0_E1_K_E2,
          typename BGridDesc_E0_E1_N_Ho_Wo_E2,
          typename CGridDesc_K_N_Ho_Wo,
          typename DGridDesc_K_N_Hx_Wx,
          index_t E1_,
          index_t E2_,
          index_t K2_,
          index_t KPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t E1PerBlock,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t EPerThread,
          typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
          typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_E2,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalStepHacks,
          typename BGlobalStepHacks,
          typename CGlobalStepHacks,
          typename DGlobalStepHacks,
          typename AGlobalMoveSliceWindowStepHacks,
          typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto E1 = Number<E1_>{};
    static constexpr auto E2 = Number<E2_>{};
    static constexpr auto K2 = Number<K2_>{};

    static constexpr auto NPerBlock = I1;

    static constexpr FloatAcc alpha = 0.3;

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align);

        return a_block_space_size * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
    {
        const auto K  = c_k_n_ho_wo_grid_desc.GetLength(I0);
        const auto N  = c_k_n_ho_wo_grid_desc.GetLength(I1);
        const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
        const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);

        const auto K0 = K / KPerBlock;
        const auto N0 = N / NPerBlock;
        const auto H0 = Ho / HoPerBlock;
        const auto W0 = Wo / WoPerBlock;

        const index_t grid_size = K0 * N0 * H0 * W0;

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0)
    {
        const bool has_main_e0_block_loop = E0 > 1;

        return has_main_e0_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop()
    {
        const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1;

        return has_main_e1_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop()
    {
        const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0;

        return has_double_tail_e1_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc)
    {
        const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0);
        const auto K  = a_e0_e1_k_e2_grid_desc.GetLength(I2);

        const auto K1 = Number<KPerBlock>{};
        const auto K0 = K / K1;

        const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor(
            a_e0_e1_k_e2_grid_desc,
            make_tuple(make_pass_through_transform(E0),
                       make_pass_through_transform(E1),
                       make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return a_e0_e1_k0_k1_e2_grid_desc;
    }

    __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(
        const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc)
    {
        const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0);
        // const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1);
        const auto N  = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2);
        const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3);
        const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4);
        // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5);

        const auto H2 = Number<HoPerThread>{};
        const auto H1 = Number<HoPerBlock / HoPerThread>{};
        const auto H0 = Ho / (H1 * H2);

        const auto W2 = Number<WoPerThread>{};
        const auto W1 = Number<WoPerBlock / WoPerThread>{};
        const auto W0 = Wo / (W1 * W2);

        const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = transform_tensor_descriptor(
            b_e0_e1_n_ho_wo_e2_grid_desc,
            make_tuple(make_pass_through_transform(E0),
                       make_pass_through_transform(E1),
                       make_pass_through_transform(N),
                       make_unmerge_transform(make_tuple(H0, H1, H2)),
                       make_unmerge_transform(make_tuple(W0, W1, W2)),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3, 4, 5>{},
                       Sequence<6, 7, 8>{},
                       Sequence<9>{}));

        return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
    {
        const auto K  = c_k_n_ho_wo_grid_desc.GetLength(I0);
        const auto N  = c_k_n_ho_wo_grid_desc.GetLength(I1);
        const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
        const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);

        const auto K1 = Number<KPerBlock>{};
        const auto K0 = K / K1;

        const auto H2 = Number<HoPerThread>{};
        const auto H1 = Number<HoPerBlock / HoPerThread>{};
        const auto H0 = Ho / (H1 * H2);

        const auto W2 = Number<WoPerThread>{};
        const auto W1 = Number<WoPerBlock / WoPerThread>{};
        const auto W0 = Wo / (W1 * W2);

        const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor(
            c_k_n_ho_wo_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_unmerge_transform(make_tuple(H0, H1, H2)),
                       make_unmerge_transform(make_tuple(W0, W1, W2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));

        return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
    {
        const auto K  = d_k_n_hx_wx_grid_desc.GetLength(I0);
        const auto N  = d_k_n_hx_wx_grid_desc.GetLength(I1);
        const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
        const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);

        const auto K1 = Number<KPerBlock>{};
        const auto K0 = K / K1;

#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
        const auto H2 = Number<HoPerThread / 2>{};
        const auto H1 = Number<HoPerBlock / HoPerThread>{};
        const auto H0 = Number<Hx / (H1 * H2)>{};

        const auto W2 = Number<WoPerThread / 2>{};
        const auto W1 = Number<WoPerBlock / WoPerThread>{};
        const auto W0 = Number<Wx / (W1 * W2)>{};
#else
        const auto H2 = HoPerThread / 2;
        const auto H1 = HoPerBlock / HoPerThread;
        const auto H0 = Hx / (H1 * H2);

        const auto W2 = WoPerThread / 2;
        const auto W1 = WoPerBlock / WoPerThread;
        const auto W0 = Wx / (W1 * W2);
#endif

        const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
            d_k_n_hx_wx_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_unmerge_transform(make_tuple(H0, H1, H2)),
                       make_unmerge_transform(make_tuple(W0, W1, W2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));

        return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
    {
        const auto K  = d_k_n_hx_wx_grid_desc.GetLength(I0);
        const auto N  = d_k_n_hx_wx_grid_desc.GetLength(I1);
        const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
        const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);

        const auto K1 = Number<KPerBlock>{};
        const auto K0 = K / K1;

        const auto H2 = Number<HoPerThread * 2>{};
        const auto H1 = Number<HoPerBlock / HoPerThread>{};

        const auto W2 = Number<WoPerThread * 2>{};
        const auto W1 = Number<WoPerBlock / WoPerThread>{};

#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
        const auto H0 = Number<Hx / (H1 * H2)>{};
        const auto W0 = Number<Wx / (W1 * W2)>{};
#else
        const auto H0 = Hx / (H1 * H2);
        const auto W0 = Wx / (W1 * W2);
#endif

        const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
            d_k_n_hx_wx_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_unmerge_transform(make_tuple(H0, H1, H2)),
                       make_unmerge_transform(make_tuple(W0, W1, W2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));

        return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
    {
        const auto K  = c_k_n_ho_wo_grid_desc.GetLength(I0);
        const auto N  = c_k_n_ho_wo_grid_desc.GetLength(I1);
        const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
        const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);

#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
        const auto K0 = Number<K / KPerBlock>{};
        const auto N0 = Number<N / NPerBlock>{};
        const auto H0 = Number<Ho / HoPerBlock>{};
        const auto W0 = Number<Wo / WoPerBlock>{};
#else
        const auto K0 = K / KPerBlock;
        const auto N0 = N / NPerBlock;
        const auto H0 = Ho / HoPerBlock;
        const auto W0 = Wo / WoPerBlock;
#endif

        const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));

        return c_blockid_to_k_n_ho_wo_block_cluster_adaptor;
    }

    // using AGridDesc_E0_E1_K0_K1_E2 =
    //     decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
    // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
    //     decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
    // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
    //     decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
    // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
    //     decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));

    using CBlockIdToBlockClusterAdaptor_K_N_H_W =
        decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));

    template <typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
    __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
    {
        const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0);
        const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1);

        return make_naive_tensor_descriptor_packed(make_tuple(K0, K1));
    }

    __host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor()
    {
        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));

        return c_k1_n_h2_w2_thread_gemm_desc;
    }

    // using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor());

    __host__ __device__ static constexpr auto GetBlockWiseGemm()
    {
        constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};

        constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);

        constexpr auto b_e1_n_h_w_e2_block_gemm_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<E1PerBlock>{},
                       I1,
                       Number<HoPerBlock>{},
                       Number<WoPerBlock>{},
                       Number<E2>{}));

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        auto blockwise_gemm =
            BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                 FloatAB,
                                                 FloatAB,
                                                 FloatAcc,
                                                 decltype(a_e1_k1_e2_block_gemm_desc),
                                                 decltype(b_e1_n_h_w_e2_block_gemm_desc),
                                                 decltype(c_k1_n_h2_w2_thread_gemm_desc),
                                                 EPerThread,
                                                 K2>{};

        return blockwise_gemm;
    }

    __device__ static constexpr auto GetCThreadIndex()
    {
        auto blockwise_gemm = GetBlockWiseGemm();

        auto c_thread_mtx_index =
            blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());

        return c_thread_mtx_index;
    };

    __device__ static constexpr auto GetCBlockIndex(
        const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor)
    {
        const auto c_k_n_h_w_block_cluster_idx =
            c_blockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(get_block_1d_id()));

        return c_k_n_h_w_block_cluster_idx;
    }

    template <typename BiasGlobalBuff,
              typename CThreadBuff,
              typename CBlockIndex,
              typename CThreadIndex,
              typename BiasGridDesc_K0_K1,
              typename CThreadDesc_K1_N_H2_W2>
    __device__ static void BiasOp(BiasGlobalBuff& bias_global_buf,
                                  CThreadBuff& c_thread_buf,
                                  const CBlockIndex& c_block_idx,
                                  const CThreadIndex& c_thread_idx,
                                  const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc,
                                  const CThreadDesc_K1_N_H2_W2&)
    {
        const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
        const auto k_thread_id        = c_thread_idx[I0];

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};

        constexpr auto bias_k0_k1_thread_desc =
            make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatC,
                     bias_k0_k1_thread_desc.GetElementSpaceSize(),
                     true>
            bias_thread_buf;

        const index_t k_thread_data_on_global = k_thread_id * KPerThread;

        auto bias_threadwise_transfer =
            ThreadwiseTensorSliceTransfer_v2<FloatC,
                                             FloatC,
                                             decltype(bias_k0_k1_grid_desc),
                                             decltype(bias_k0_k1_thread_desc),
                                             Sequence<I1, Number<KPerThread>{}>,
                                             Sequence<0, 1>,
                                             1,
                                             CThreadTransferDstScalarPerVector,
                                             false,
                                             true>(
                bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global));

        constexpr auto bias_k0_k1_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0>{}, Sequence<0>{}),
                       make_tuple(Sequence<0>{}, Sequence<0>{}));

        bias_threadwise_transfer.Run(bias_k0_k1_grid_desc,
                                     bias_global_buf,
                                     bias_k0_k1_thread_desc,
                                     make_tuple(I0, I0),
                                     bias_thread_buf,
                                     bias_k0_k1_global_tensor_step_hacks);

        static_for<0, KPerThread, 1>{}([&](auto ki) {
            static_for<0, HoPerThread, 1>{}([&](auto hi) {
                static_for<0, WoPerThread, 1>{}([&](auto wi) {
                    constexpr index_t c_offset =
                        c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi));

                    c_thread_buf(Number<c_offset>{}) =
                        c_thread_buf[Number<c_offset>{}] + bias_thread_buf[ki];
                });
            });
        });
    }

    template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum_t activ_type_>
    __device__ static void Activation(CThreadBuff& c_thread_buf,
                                      const CThreadDesc_K1_N_H2_W2&,
                                      integral_constant<ActivTypeEnum_t, activ_type_>)
    {
        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};

        static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
            if constexpr(activ_type_ == 1)
            {
                c_thread_buf(i) =
                    c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
            }
            else if constexpr(activ_type_ == 2)
            {
                FloatAcc x = 1.0 + exp(-c_thread_buf[i]);

                asm volatile("\n \
                        v_rcp_f32 %0, %1 \n"
                             : "=v"(x)
                             : "0"(x));

                c_thread_buf(i) = x;
            }
        });
    }
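For reference, the two activation branches above are a leaky ReLU with negative slope alpha and a sigmoid written as the reciprocal of 1 + exp(-x); the kernel replaces the final division with the v_rcp_f32 hardware reciprocal. The plain C++ restatement below is an illustration only (not part of the commit) and uses a regular divide instead of the inline asm.

#include <cmath>
#include <cstdio>

// Leaky ReLU with negative slope alpha, matching the activ_type_ == 1 branch above.
float leaky_relu(float x, float alpha = 0.3f) { return x >= 0.0f ? x : alpha * x; }

// Sigmoid written as the reciprocal of 1 + exp(-x), matching the activ_type_ == 2 branch
// (the kernel computes the reciprocal with v_rcp_f32 instead of a divide).
float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main()
{
    for(float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
        std::printf("x=% .1f  leaky_relu=% .3f  sigmoid=% .3f\n", x, leaky_relu(x), sigmoid(x));
    return 0;
}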
template
<
typename
CThreadBuff
,
typename
CGlobalBuff
,
typename
CBlockIndex
,
typename
CThreadIndex
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
__device__
static
void
WriteOut
(
const
CThreadBuff
&
c_thread_buf
,
CGlobalBuff
&
c_global_buf
,
const
CBlockIndex
&
c_block_idx
,
const
CThreadIndex
&
c_thread_idx
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
)
{
const
index_t
k_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I0
]);
const
index_t
n_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I1
]);
const
index_t
ho_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I2
]);
const
index_t
wo_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I3
]);
const
auto
k_thread_id
=
c_thread_idx
[
I0
];
const
auto
ho_thread_id
=
c_thread_idx
[
I2
];
const
auto
wo_thread_id
=
c_thread_idx
[
I3
];
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr
auto
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
=
CGlobalStepHacks
{};
constexpr
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread
>
{},
I1
,
I1
,
Number
<
WoPerThread
>
{}));
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
),
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThread
,
I1
,
I1
,
WoPerThread
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_global_buf
,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
);
}
template
<
typename
CThreadBuff
,
typename
DGlobalBuff
,
typename
CBlockIndex
,
typename
CThreadIndex
,
typename
CThreadDesc_K1_N_H2_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
>
__device__
static
void
MaxPool
(
const
CThreadBuff
&
c_thread_buf
,
DGlobalBuff
&
d_global_buf
,
const
CBlockIndex
&
c_block_idx
,
const
CThreadIndex
&
c_thread_idx
,
const
CThreadDesc_K1_N_H2_W2
&
,
const
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
&
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
)
{
const
index_t
k_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I0
]);
const
index_t
n_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I1
]);
const
index_t
ho_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I2
]);
const
index_t
wo_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_block_idx
[
I3
]);
const
auto
k_thread_id
=
c_thread_idx
[
I0
];
const
auto
ho_thread_id
=
c_thread_idx
[
I2
];
const
auto
wo_thread_id
=
c_thread_idx
[
I3
];
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
CThreadDesc_K1_N_H2_W2
{};
static_assert
(
HoPerThread
%
2
==
0
&&
WoPerThread
%
2
==
0
,
""
);
constexpr
auto
HoPerThread_2
=
HoPerThread
/
2
;
constexpr
auto
WoPerThread_2
=
WoPerThread
/
2
;
constexpr
auto
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread_2
>
{},
I1
,
I1
,
Number
<
WoPerThread_2
>
{}));
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatC
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
.
GetElementSpaceSize
(),
true
>
d_thread_buf
;
static_for
<
0
,
KPerThread
,
1
>
{}([
&
](
auto
ki
)
{
static_for
<
0
,
HoPerThread_2
,
1
>
{}([
&
](
auto
hi
)
{
static_for
<
0
,
WoPerThread_2
,
1
>
{}([
&
](
auto
wi
)
{
constexpr
index_t
d_offset
=
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
ki
,
0
,
0
,
0
,
hi
,
0
,
0
,
wi
));
constexpr
index_t
c_offset_0
=
c_k1_n_h2_w2_thread_gemm_desc
.
CalculateOffset
(
make_tuple
(
ki
,
0
,
hi
*
2
,
wi
*
2
));
constexpr
index_t
c_offset_1
=
c_k1_n_h2_w2_thread_gemm_desc
.
CalculateOffset
(
make_tuple
(
ki
,
0
,
hi
*
2
,
wi
*
2
+
1
));
constexpr
index_t
c_offset_2
=
c_k1_n_h2_w2_thread_gemm_desc
.
CalculateOffset
(
make_tuple
(
ki
,
0
,
hi
*
2
+
1
,
wi
*
2
));
constexpr
index_t
c_offset_3
=
c_k1_n_h2_w2_thread_gemm_desc
.
CalculateOffset
(
make_tuple
(
ki
,
0
,
hi
*
2
+
1
,
wi
*
2
+
1
));
d_thread_buf
(
Number
<
d_offset
>
{})
=
c_thread_buf
[
Number
<
c_offset_0
>
{}];
d_thread_buf
(
Number
<
d_offset
>
{})
=
fmaxf
(
c_thread_buf
[
Number
<
c_offset_1
>
{}],
d_thread_buf
(
Number
<
d_offset
>
{}));
d_thread_buf
(
Number
<
d_offset
>
{})
=
fmaxf
(
c_thread_buf
[
Number
<
c_offset_2
>
{}],
d_thread_buf
(
Number
<
d_offset
>
{}));
d_thread_buf
(
Number
<
d_offset
>
{})
=
fmax
(
c_thread_buf
[
Number
<
c_offset_3
>
{}],
d_thread_buf
(
Number
<
d_offset
>
{}));
});
});
});
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
constexpr
auto
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks
=
DGlobalStepHacks
{};
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
FloatC
,
decltype
(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
),
decltype
(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThread_2
,
I1
,
I1
,
WoPerThread_2
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
true
>
(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d_thread_buf
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
,
d_global_buf
,
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks
);
}
    template <typename CThreadBuff,
              typename DGlobalBuff,
              typename CBlockIndex,
              typename CThreadIndex,
              typename CThreadDesc_K1_N_H2_W2,
              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>
    __device__ static void
    ResizeAdd(const CThreadBuff& c_thread_buf,
              DGlobalBuff& d_global_buf,
              const CBlockIndex& c_block_idx,
              const CThreadIndex& c_thread_idx,
              const CThreadDesc_K1_N_H2_W2&,
              const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
    {
        const index_t k_block_work_id  = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
        const index_t n_block_work_id  = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
        const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
        const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);

        const auto k_thread_id  = c_thread_idx[I0];
        const auto ho_thread_id = c_thread_idx[I2];
        const auto wo_thread_id = c_thread_idx[I3];

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};

        constexpr auto HoPerThreadx2 = HoPerThread * 2;
        constexpr auto WoPerThreadx2 = WoPerThread * 2;

        constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<KPerThread>{}, I1, I1, I1, Number<HoPerThreadx2>{}, I1, I1, Number<WoPerThreadx2>{}));

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatC,
                     d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
                     true>
            d_thread_buf;

        static_for<0, KPerThread, 1>{}([&](auto k_i) {
            static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
                static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
                    d_thread_buf(Number<d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
                        make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
                        c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
                            make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
                });
            });
        });

        // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
        constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};

        const index_t k_thread_data_on_global = k_thread_id * KPerThread;

        ThreadwiseTensorSliceTransfer_v1r3<
            FloatC,
            FloatC,
            decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
            decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
            Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            InMemoryDataOperationEnum_t::Add,
            1,
            true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                  make_multi_index(k_block_work_id, k_thread_data_on_global, n_block_work_id,
                                   ho_block_work_id, ho_thread_id, 0, wo_block_work_id, wo_thread_id, 0))
            .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
                 d_thread_buf,
                 d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                 d_global_buf,
                 d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
    }
    template <typename AGlobalBuff,
              typename BGlobalBuff,
              typename CThreadBuff,
              typename CBlockIndex,
              typename CThreadIndex,
              typename AGridDesc_E0_E1_K0_K1_E2,
              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
              typename CThreadDesc_K1_N_H2_W2,
              bool HasMainE0BlockLoop>
    __device__ static void
    GemmOp(const AGlobalBuff& a_global_buf,
           const BGlobalBuff& b_global_buf,
           CThreadBuff& c_thread_buf,
           FloatAB* __restrict__ p_shared_block,
           const CBlockIndex& c_block_idx,
           const CThreadIndex& c_thread_idx,
           const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
           const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
           const CThreadDesc_K1_N_H2_W2&,
           integral_constant<bool, HasMainE0BlockLoop>)
    {
        constexpr auto HasMainE1BlockLoop       = CalculateHasMainE1BlockLoop();
        constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop();

        // const auto c_k_n_h_w_block_cluster_idx =
        //     GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
        //     c_blockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
        //         make_multi_index(get_block_1d_id()));

        const index_t k_block_work_id  = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
        const index_t n_block_work_id  = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
        const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
        const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);

        constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};

        constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);

        constexpr auto b_e1_n_h_w_e2_block_gemm_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<E1PerBlock>{}, I1, Number<HoPerBlock>{}, Number<WoPerBlock>{}, Number<E2>{}));

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};

        auto blockwise_gemm = BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                                   FloatAB,
                                                                   FloatAB,
                                                                   FloatAcc,
                                                                   decltype(a_e1_k1_e2_block_gemm_desc),
                                                                   decltype(b_e1_n_h_w_e2_block_gemm_desc),
                                                                   decltype(c_k1_n_h2_w2_thread_gemm_desc),
                                                                   EPerThread,
                                                                   K2>{};

        // blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
        const auto ho_thread_id = c_thread_idx[I2];
        const auto wo_thread_id = c_thread_idx[I3];

        constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
                                            Sequence<I1, E1, I1, KPerBlock, E2>,
                                            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
                                            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
                                            ABlockTransferThreadClusterArrangeOrder,
                                            FloatAB,
                                            FloatAB,
                                            decltype(a_e0_e1_k0_k1_e2_grid_desc),
                                            decltype(a_e0_e1_k0_k1_e2_block_copy_desc),
                                            ABlockTransferSrcAccessOrder,
                                            Sequence<0, 1, 2, 3, 4>,
                                            ABlockTransferSrcVectorDim,
                                            4,
                                            ABlockTransferSrcScalarPerVector,
                                            ABlockTransferDstScalarPerVector_E2,
                                            1,
                                            1,
                                            AThreadTransferSrcResetCoordinateAfterRun,
                                            false>(a_e0_e1_k0_k1_e2_grid_desc,
                                                   make_multi_index(0, 0, k_block_work_id, 0, 0),
                                                   a_e0_e1_k0_k1_e2_block_copy_desc,
                                                   make_multi_index(0, 0, 0, 0, 0));

        constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0);

        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc = make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<E1PerBlock>{}, I1, I1, I1, Number<HoPerThread>{}, I1, I1,
                       Number<WoPerThread>{}, Number<E2>{}));

        auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
            FloatAB,
            FloatAB,
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc),
            Sequence<I1, E1PerBlock, I1, I1, I1, HoPerThread, I1, I1, WoPerThread, E2>,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                  make_multi_index(0, 0, n_block_work_id, ho_block_work_id, ho_thread_id, 0,
                                   wo_block_work_id, wo_thread_id, 0, 0));

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());

        //// register allocation for output
        // StaticBuffer<AddressSpaceEnum_t::Vgpr,
        //              FloatAcc,
        //              c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
        //              true>
        //     c_thread_buf;

        // initialize output thread tensor
        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                    decltype(c_k1_n_h2_w2_thread_gemm_desc),
                                    Sequence<KPerThread, I1, HoPerThread, WoPerThread>>{}
            .Run(c_k1_n_h2_w2_thread_gemm_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

        constexpr auto b_thread_slice_copy_step = make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
        constexpr auto a_e0_e1_k_e2_global_step_hacks                   = AGlobalStepHacks{};
        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};

        // double register buffer for b
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAB,
                     b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
                     true>
            b_thread_even_buf, b_thread_odd_buf;

        if constexpr(HasMainE0BlockLoop)
        {
            const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0);

            index_t e0_block_data_begin = 0;

            do
            {
                // LDS double buffer: preload data
                {
                    a_blockwise_copy.RunRead(a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);

                    b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                        b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                        b_thread_even_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                    a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
                }

                __syncthreads();

                if constexpr(HasMainE1BlockLoop)
                {
                    index_t e1_block_data_begin = 0;

                    // LDS double buffer: main body
                    // use Do-While loop instead of For loop to simplify control flow
                    do
                    {
                        // even iteration
                        b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                            b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                        b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                            b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                            make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                            b_thread_odd_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                        // LDS double buffer: GEMM on current data
                        blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);

                        blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));

                        b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                            b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                        b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                            b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                            make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                            b_thread_even_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                        // LDS double buffer: GEMM on current data
                        blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);

                        blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));

                        e1_block_data_begin += 2 * E1PerBlock;
                    } while(e1_block_data_begin < E1 - 2 * E1PerBlock);
                }

                // LDS double buffer: tail
                if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
                {
                    b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                        b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                    b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                        b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                        b_thread_odd_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                    // LDS double buffer: GEMM on 2nd-last data
                    blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);

                    blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));

                    // LDS double buffer: GEMM on last data
                    blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
                }
                else // if has 1 iteration left
                {
                    // LDS double buffer: GEMM on last data
                    blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
                }

                a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc,
                    a_block_slice_copy_step, AGlobalMoveSliceWindowStepHacks{});

                blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0));

                b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                    b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                e0_block_data_begin += 1;

            } while(e0_block_data_begin < E0);
        }
        else
        {
            // LDS double buffer: preload data
            {
                a_blockwise_copy.RunRead(a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);

                b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                    b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                    b_thread_even_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
            }

            __syncthreads();

            if constexpr(HasMainE1BlockLoop)
            {
                index_t e1_block_data_begin = 0;

                // LDS double buffer: main body
                // use Do-While loop instead of For loop to simplify control flow
                do
                {
                    // even iteration
                    b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                        b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                    b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                        b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                        b_thread_odd_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                    // LDS double buffer: GEMM on current data
                    blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);

                    blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));

                    b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                        b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                    b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                        b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                        b_thread_even_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                    // LDS double buffer: GEMM on current data
                    blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);

                    blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));

                    e1_block_data_begin += 2 * E1PerBlock;
                } while(e1_block_data_begin < E1 - 2 * E1PerBlock);
            }

            // LDS double buffer: tail
            if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
            {
                b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                    b_thread_slice_copy_step, BGlobalMoveSliceWindowStepHacks{});

                b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_global_buf,
                    b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                    b_thread_odd_buf, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);

                blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
            }
            else // if has 1 iteration left
            {
                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
            }
        }
    }
    template <typename AGridDesc_E0_E1_K0_K1_E2,
              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop>
    __device__ static void
    Conv(const FloatAB* __restrict__ p_a_global,
         const FloatAB* __restrict__ p_b_global,
         const FloatC* __restrict__ p_bias_global,
         FloatC* __restrict__ p_c_global,
         FloatC* __restrict__ p_d_global,
         FloatAB* __restrict__ p_shared_block,
         const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
         const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
         const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
         const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
         const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
         integral_constant<bool, HasMainE0BlockLoop>)
    {
        const auto bias_k0_k1_grid_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
        auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
            c_thread_buf;

        const auto c_k_n_h_w_block_cluster_idx = GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
        const auto c_thread_mtx_index          = GetCThreadIndex();

        // GemmOp
        GemmOp(a_global_buf, b_global_buf, c_thread_buf, p_shared_block,
               c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               a_e0_e1_k0_k1_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
               c_k1_n_h2_w2_thread_gemm_desc, integral_constant<bool, HasMainE0BlockLoop>{});

        // Output
        WriteOut(c_thread_buf, c_global_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
                 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
    }
    template <typename AGridDesc_E0_E1_K0_K1_E2,
              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop,
              ActivTypeEnum_t ActivType>
    __device__ static void
    ConvBiasActiv(const FloatAB* __restrict__ p_a_global,
                  const FloatAB* __restrict__ p_b_global,
                  const FloatC* __restrict__ p_bias_global,
                  FloatC* __restrict__ p_c_global,
                  FloatAB* __restrict__ p_shared_block,
                  const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
                  const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                  const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                  const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
                  integral_constant<bool, HasMainE0BlockLoop>,
                  integral_constant<ActivTypeEnum_t, ActivType>)
    {
        static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};

        const auto bias_k0_k1_grid_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
            c_thread_buf;

        const auto c_k_n_h_w_block_cluster_idx = GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
        const auto c_thread_mtx_index          = GetCThreadIndex();

        // GemmOp
        GemmOp(a_global_buf, b_global_buf, c_thread_buf, p_shared_block,
               c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               a_e0_e1_k0_k1_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
               c_k1_n_h2_w2_thread_gemm_desc, integral_constant<bool, HasMainE0BlockLoop>{});

        // Bias
        BiasOp(bias_global_buf, c_thread_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               bias_k0_k1_grid_desc, c_k1_n_h2_w2_thread_gemm_desc);

        // Activ
        Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);

        // Output
        WriteOut(c_thread_buf, c_global_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
                 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
    }
    template <typename AGridDesc_E0_E1_K0_K1_E2,
              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop,
              ActivTypeEnum_t ActivType>
    __device__ static void
    ConvBiasActivMaxpool(const FloatAB* __restrict__ p_a_global,
                         const FloatAB* __restrict__ p_b_global,
                         const FloatC* __restrict__ p_bias_global,
                         FloatC* __restrict__ p_c_global,
                         FloatC* __restrict__ p_d_global,
                         FloatAB* __restrict__ p_shared_block,
                         const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
                         const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                         const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                         const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                         const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
                         integral_constant<bool, HasMainE0BlockLoop>,
                         integral_constant<ActivTypeEnum_t, ActivType>)
    {
        static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};

        const auto bias_k0_k1_grid_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
        auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
            c_thread_buf;

        const auto c_k_n_h_w_block_cluster_idx = GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
        const auto c_thread_mtx_index          = GetCThreadIndex();

        // GemmOp
        GemmOp(a_global_buf, b_global_buf, c_thread_buf, p_shared_block,
               c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               a_e0_e1_k0_k1_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
               c_k1_n_h2_w2_thread_gemm_desc, integral_constant<bool, HasMainE0BlockLoop>{});

        // Bias
        BiasOp(bias_global_buf, c_thread_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               bias_k0_k1_grid_desc, c_k1_n_h2_w2_thread_gemm_desc);

        // Activ
        Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);

        // Output
        WriteOut(c_thread_buf, c_global_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
                 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

        // MaxPool
        MaxPool(c_thread_buf, d_global_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
                c_k1_n_h2_w2_thread_gemm_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
    }
    template <typename AGridDesc_E0_E1_K0_K1_E2,
              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop,
              ActivTypeEnum_t ActivType>
    __device__ static void
    ConvBiasActivResizeAdd(const FloatAB* __restrict__ p_a_global,
                           const FloatAB* __restrict__ p_b_global,
                           const FloatC* __restrict__ p_bias_global,
                           FloatC* __restrict__ p_d_global,
                           FloatAB* __restrict__ p_shared_block,
                           const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
                           const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                           const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                           const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                           const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
                           integral_constant<bool, HasMainE0BlockLoop>,
                           integral_constant<ActivTypeEnum_t, ActivType>)
    {
        static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};

        const auto bias_k0_k1_grid_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
        auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
            c_thread_buf;

        const auto c_k_n_h_w_block_cluster_idx = GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
        const auto c_thread_mtx_index          = GetCThreadIndex();

        // GemmOp
        GemmOp(a_global_buf, b_global_buf, c_thread_buf, p_shared_block,
               c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               a_e0_e1_k0_k1_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
               c_k1_n_h2_w2_thread_gemm_desc, integral_constant<bool, HasMainE0BlockLoop>{});

        // Bias
        BiasOp(bias_global_buf, c_thread_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
               bias_k0_k1_grid_desc, c_k1_n_h2_w2_thread_gemm_desc);

        // Activ
        Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);

        // Resize_Add
        ResizeAdd(c_thread_buf, d_global_buf, c_k_n_h_w_block_cluster_idx, c_thread_mtx_index,
                  c_k1_n_h2_w2_thread_gemm_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
    }
};

} // namespace ck
#endif
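For reference, the MaxPool and ResizeAdd epilogues above come down to simple index arithmetic once the descriptor machinery is stripped away. The sketch below is a minimal host-side illustration and not part of the commit: it assumes a dense K x Ho x Wo conv output tile, and reproduces the 2x2-window max used by MaxPool and the 2x nearest-neighbour replicate-and-accumulate used by ResizeAdd.

#include <algorithm>
#include <vector>

// Host-side sketch of the two epilogues (illustrative only, not CK code).
// c: conv output, K x Ho x Wo, densely packed.
void maxpool2x2(const std::vector<float>& c, std::vector<float>& d, int K, int Ho, int Wo)
{
    // d: K x (Ho/2) x (Wo/2); each element is the max over a 2x2 window of c,
    // matching the fmax chain over c_offset_0..c_offset_3 in MaxPool.
    for(int k = 0; k < K; ++k)
        for(int h = 0; h < Ho / 2; ++h)
            for(int w = 0; w < Wo / 2; ++w)
            {
                auto at = [&](int hh, int ww) { return c[(k * Ho + hh) * Wo + ww]; };
                d[(k * (Ho / 2) + h) * (Wo / 2) + w] =
                    std::max(std::max(at(2 * h, 2 * w), at(2 * h, 2 * w + 1)),
                             std::max(at(2 * h + 1, 2 * w), at(2 * h + 1, 2 * w + 1)));
            }
}

void resize_add2x(const std::vector<float>& c, std::vector<float>& d, int K, int Ho, int Wo)
{
    // d: K x (2*Ho) x (2*Wo); ResizeAdd replicates each conv output into a 2x2 block
    // (the h_i / 2, w_i / 2 indexing above) and accumulates it into d via the Add transfer.
    for(int k = 0; k < K; ++k)
        for(int h = 0; h < 2 * Ho; ++h)
            for(int w = 0; w < 2 * Wo; ++w)
                d[(k * 2 * Ho + h) * 2 * Wo + w] += c[(k * Ho + h / 2) * Wo + w / 2];
}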
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp  View file @ a037693f
...
@@ -9,21 +9,22 @@ namespace ck {

 // C[M, N] += transpose(A[K, M]) * B[K, N]
 //   Element of matrix can be vectorized data
 // Assume:
-//   1. ADesc, BDesc, CDesc are known at compile-time
+//   1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
+//   compile-time
 //   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
 template <typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename ADesc,
-          typename BDesc,
-          typename CDesc,
-          index_t H,
-          index_t W,
-          typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                                 CDesc::IsKnownAtCompileTime(),
+          typename AThreadDesc_E1_K_E2,
+          typename BThreadDesc_E1_N_Ho_Wo_E2,
+          typename CThreadDesc_K_N_Ho_Wo,
+          typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
+                                 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
+                                 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
                              bool>::type = false>
 struct ThreadwiseGemmDlops_km_kn_mn_v3
 {
     template <typename ABuffer,
               typename AOriginIdx,
               typename BBuffer,
...
@@ -37,8 +38,10 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
               CBuffer& c_buf,
               COriginIdx)
     {
-        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                          CDesc::IsKnownAtCompileTime(),
+        static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
+                          BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
+                          CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
...
@@ -54,102 +57,107 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr auto E = ADesc{}.GetLength(I0);
-        constexpr auto K = ADesc{}.GetLength(I1);
+        constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
+        constexpr auto K  = AThreadDesc_E1_K_E2{}.GetLength(I1);
+        constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
+        constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
+        constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);

         constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
         constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
         constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

-        static_for<0, E, 1>{}([&](auto e) {
-            static_for<0, K, 1>{}([&](auto k) {
-                constexpr index_t a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));
-
-                if constexpr(H == 2 && W == 2)
-                {
-                    constexpr index_t b_offset_0 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
-                    constexpr index_t b_offset_1 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
-                    constexpr index_t b_offset_2 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
-                    constexpr index_t b_offset_3 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
-
-                    constexpr index_t c_offset_0 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
-                    constexpr index_t c_offset_1 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
-                    constexpr index_t c_offset_2 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
-                    constexpr index_t c_offset_3 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
-
-                    amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
-                                                   b_buf[Number<b_offset_0>{}], b_buf[Number<b_offset_1>{}],
-                                                   b_buf[Number<b_offset_2>{}], b_buf[Number<b_offset_3>{}],
-                                                   c_buf(Number<c_offset_0>{}), c_buf(Number<c_offset_1>{}),
-                                                   c_buf(Number<c_offset_2>{}), c_buf(Number<c_offset_3>{}));
-                }
-                else if constexpr(H == 4 && W == 1)
-                {
-                    constexpr index_t b_offset_0 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
-                    constexpr index_t b_offset_1 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
-                    constexpr index_t b_offset_2 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
-                    constexpr index_t b_offset_3 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
-
-                    constexpr index_t c_offset_0 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
-                    constexpr index_t c_offset_1 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
-                    constexpr index_t c_offset_2 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
-                    constexpr index_t c_offset_3 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
-
-                    amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
-                                                   b_buf[Number<b_offset_0>{}], b_buf[Number<b_offset_1>{}],
-                                                   b_buf[Number<b_offset_2>{}], b_buf[Number<b_offset_3>{}],
-                                                   c_buf(Number<c_offset_0>{}), c_buf(Number<c_offset_1>{}),
-                                                   c_buf(Number<c_offset_2>{}), c_buf(Number<c_offset_3>{}));
-                }
-                else
-                {
-                    static_for<0, H, 1>{}([&](auto h) {
-                        static_for<0, W, 1>{}([&](auto w) {
-                            constexpr index_t b_offset = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
-                            constexpr index_t c_offset = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
-#if 0
-                            c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
-                                a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
-#else
-                            amd_assembly_inner_product(a_buf[Number<a_offset>{}],
-                                                       b_buf[Number<b_offset>{}],
-                                                       c_buf(Number<c_offset>{}));
-#endif
-                        });
-                    });
-                }
-            });
-        });
+        if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
+        {
+            constexpr auto SubHW = 2;
+
+            static_for<0, K, 1>{}([&](auto k) {
+                static_for<0, Ho, SubHW>{}([&](auto h) {
+                    static_for<0, Wo, SubHW>{}([&](auto w) {
+                        static_for<0, E1, 1>{}([&](auto e1) {
+                            static_for<0, E2, 1>{}([&](auto e2) {
+                                constexpr index_t a_offset =
+                                    AThreadDesc_E1_K_E2{}.CalculateOffset(a_origin_idx + make_tuple(e1, k, e2));
+
+                                constexpr index_t b0_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(b_origin_idx + make_tuple(e1, 0, h, w, e2));
+                                constexpr index_t b1_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
+                                constexpr index_t b2_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
+                                constexpr index_t b3_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
+
+                                constexpr index_t c0_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
+                                constexpr index_t c1_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w + 1));
+                                constexpr index_t c2_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h + 1, w));
+                                constexpr index_t c3_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
+
+                                amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
+                                                               b_buf[Number<b0_offset>{}], b_buf[Number<b1_offset>{}],
+                                                               b_buf[Number<b2_offset>{}], b_buf[Number<b3_offset>{}],
+                                                               c_buf(Number<c0_offset>{}), c_buf(Number<c1_offset>{}),
+                                                               c_buf(Number<c2_offset>{}), c_buf(Number<c3_offset>{}));
+                            });
+                        });
+                    });
+                });
+            });
+        }
+        else
+        {
+            static_for<0, K, 1>{}([&](auto k) {
+                static_for<0, Ho, 1>{}([&](auto h) {
+                    static_for<0, Wo, 1>{}([&](auto w) {
+                        static_for<0, E1, 1>{}([&](auto e1) {
+                            static_for<0, E2, 1>{}([&](auto e2) {
+                                constexpr index_t a_offset =
+                                    AThreadDesc_E1_K_E2{}.CalculateOffset(a_origin_idx + make_tuple(e1, k, e2));
+                                constexpr index_t b_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(b_origin_idx + make_tuple(e1, 0, h, w, e2));
+                                constexpr index_t c_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
+
+                                inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
+                                                                      b_buf[Number<b_offset>{}],
+                                                                      c_buf(Number<c_offset>{}));
+                            });
+                        });
+                    });
+                });
+            });
+        }
     }
 };
...
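The dlops threadwise GEMM above ultimately issues 1x4 outer-product steps; stripped of the descriptor offsets, each step is four fused multiply-accumulates that share one A operand (for int8x4_t operands the same pattern maps onto v_dot4_i32_i8, i.e. each multiply becomes a 4-wide dot product). A minimal plain-C++ sketch of what one such step computes (illustrative only):

// Reference for one amd_assembly_outer_product_1x4 step with float operands:
// four accumulators share the same A element, each paired with a different B element.
inline void outer_product_1x4_ref(float a, float b0, float b1, float b2, float b3,
                                  float& c0, float& c1, float& c2, float& c3)
{
    c0 += a * b0;
    c1 += a * b1;
    c2 += a * b2;
    c3 += a * b3;
}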
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp  View file @ a037693f
...
@@ -217,6 +217,22 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                                                    is_dst_valid,
                                                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
             }
             else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add)
             {
                 typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;

                 tmp.template AsType<dst_vector_t>()(Number<0>{}) =
                     dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);

                 static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
                     dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
                 });

                 dst_buf.template Set<dst_vector_t>(dst_coord_.GetOffset(),
                                                    is_dst_valid,
                                                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
             }

             constexpr auto move_on_dim = [&]() constexpr
             {
...
@@ -666,6 +682,25 @@ struct ThreadwiseTensorSliceTransfer_v2
         move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
     }

     // src_slice_origin_step_idx need to be known at compile-time, for performance reason
     template <typename SrcMoveSliceWindowStepHack>
     __device__ void
     MoveSrcSliceWindow(const SrcDesc& src_desc,
                        const Index& src_slice_origin_step_idx,
                        const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
     {
         // if src coord was not reset by RunRead(), then need to adjust the step here
         const auto adjusted_step_idx =
             SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                        : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

         // is it OK to construct a new step every time?
         const auto adjusted_step =
             make_tensor_coordinate_step(src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

         move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
     }

     private:
     SrcCoord src_coord_;
 };
 // namespace ck
...
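The new InMemoryDataOperationEnum_t::Add branch is a plain read-modify-write on the destination vector. A minimal scalar sketch of the same semantics, shown on an ordinary array rather than the CK DynamicBuffer API (illustrative only):

// Read-modify-write with the same semantics as the Add destination operation:
// load the current destination, accumulate the source into it, store it back.
template <typename T, int N>
void store_add(T* dst, const T (&src)[N])
{
    for(int i = 0; i < N; ++i)
    {
        T tmp  = dst[i];        // Get
        dst[i] = tmp + src[i];  // Set with the accumulated value
    }
}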
composable_kernel/include/utility/amd_buffer_addressing.hpp  View file @ a037693f
...
@@ -268,14 +268,14 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<double>(tmp);
+        return bit_cast<double>(tmp);
     }
     else if constexpr(N == 2)
     {
         const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<double2_t>(tmp);
+        return bit_cast<double2_t>(tmp);
     }
     else if constexpr(N == 4)
     {
...
@@ -289,8 +289,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             0);

         vector_type<double, 4> tmp;

-        tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
-        tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);
+        tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
+        tmp.AsType<double2_t>()(Number<1>{}) = bit_cast<double2_t>(f32_1);

         return tmp.AsType<double4_t>()(Number<0>{});
     }
...
@@ -351,7 +351,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<half8_t>(tmp);
+        return bit_cast<half8_t>(tmp);
     }
 }
 else if constexpr(is_same<T, ushort>::value)
...
@@ -376,7 +376,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<ushort8_t>(tmp);
+        return bit_cast<ushort8_t>(tmp);
     }
 }
 else if constexpr(is_same<T, int32_t>::value)
...
@@ -427,7 +427,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<int8x2_t>(tmp);
+        return bit_cast<int8x2_t>(tmp);
 #endif
     }
     else if constexpr(N == 4)
...
@@ -439,7 +439,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<int8x4_t>(tmp);
+        return bit_cast<int8x4_t>(tmp);
 #endif
     }
     else if constexpr(N == 8)
...
@@ -461,7 +461,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<int8x8_t>(tmp);
+        return bit_cast<int8x8_t>(tmp);
 #endif
     }
     else if constexpr(N == 16)
...
@@ -495,7 +495,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<int8x16_t>(tmp);
+        return bit_cast<int8x16_t>(tmp);
 #endif
     }
 }
...
@@ -521,7 +521,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     // use fp32 store to mimic fp64 store
     if constexpr(N == 1)
     {
-        llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<float2_t>(src_thread_data),
                                             dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset,
...
@@ -529,7 +529,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 2)
     {
-        llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
                                             dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset,
...
@@ -591,6 +591,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 8)
     {
#if 0
        vector_type<half_t, 8> tmp{src_thread_data};

        llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
...
@@ -604,6 +605,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                            dst_thread_addr_offset,
                                            dst_wave_addr_offset + 4 * sizeof(half_t),
                                            0);
#else
         llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
                                             dst_wave_buffer_resource, dst_thread_addr_offset,
                                             dst_wave_addr_offset, 0);
#endif
     }
 }
 else if constexpr(is_same<T, ushort>::value)
...
@@ -695,7 +703,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                          dst_wave_addr_offset,
                                          0);
 #else
-        llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
                                          dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset,
...
@@ -711,7 +719,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                          dst_wave_addr_offset,
                                          0);
 #else
-        llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                          dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset,
...
@@ -720,7 +728,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 8)
     {
-        llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
                                            dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset,
...
@@ -728,7 +736,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 16)
     {
-        llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
                                            dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset,
...
composable_kernel/include/utility/amd_inline_asm.hpp  View file @ a037693f
...
@@ -211,14 +211,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
                  v_dot4_i32_i8 %1, %2, %4, %1 \n \
                  "
                  : "=v"(c0), "=v"(c1)
-                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b0)), "v"(as_type<int32_t>(b1)),
+                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b0)), "v"(bit_cast<int32_t>(b1)),
                    "0"(c0), "1"(c1));
 #else
-    c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
-    c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
+    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
+    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
 #endif
 }
...
@@ -244,20 +244,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                  v_dot4_i32_i8 %3, %4, %8, %3 \n \
                  "
                  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
-                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b0)), "v"(as_type<int32_t>(b1)),
-                   "v"(as_type<int32_t>(b2)), "v"(as_type<int32_t>(b3)),
+                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b0)), "v"(bit_cast<int32_t>(b1)),
+                   "v"(bit_cast<int32_t>(b2)), "v"(bit_cast<int32_t>(b3)),
                    "0"(c0), "1"(c1), "2"(c2), "3"(c3));
 #else
-    c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
-    c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
-    c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
-    c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
+    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
+    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
+    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
+    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
 #endif
 }
...
composable_kernel/include/utility/amd_xdlops.hpp  View file @ a037693f
...
@@ -340,8 +340,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x16_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
-                                                  as_type<int>(reg_b),
+            llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
+                                                  bit_cast<int>(reg_b),
                                                   reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                   0,
                                                   0,
...
@@ -359,8 +359,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
-                                                   as_type<int>(reg_b),
+            llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
+                                                   bit_cast<int>(reg_b),
                                                    reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                    0,
                                                    0,
...
composable_kernel/include/utility/config.hpp  View file @ a037693f
...
@@ -96,9 +96,22 @@
 // pass tensor descriptor by value or void*
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0

 #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0

 // merge transformation use magic number division
 #ifndef CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
 #endif

 // use __builtin_memcpy instead of pointer cast to access a vector from pointer of scalar
 #ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
 #endif

 // use __builtin_memcpy instead of union to do bit_cast
 #ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
 #endif

 // hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...
@@ -118,7 +131,7 @@
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
 #endif

-// workaround for compiler crash when using buffer load/store for i8
+// workaround for compiler generating inefficient ds_write instructions
 #ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif
...
@@ -133,7 +146,15 @@ namespace ck {

 enum InMemoryDataOperationEnum_t
 {
     Set,
-    AtomicAdd
+    AtomicAdd,
+    Add
 };

 enum ActivTypeEnum_t
 {
     None = 0,
     LeakyRelu,
     Sigmoid
 };

 // index type
...
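The new ActivTypeEnum_t is dispatched at compile time (the gridwise kernel earlier in this commit passes it as an integral_constant). A minimal, self-contained sketch of such a dispatch is shown below; the enum values mirror the ones added above, and the leaky-relu slope is a placeholder chosen only for illustration, not a value taken from this commit.

#include <cmath>

// Sketch of a compile-time activation dispatch in the style of ActivTypeEnum_t.
enum ActivType { None = 0, LeakyRelu, Sigmoid };

template <ActivType A>
float apply_activ(float x)
{
    if constexpr(A == LeakyRelu)
        return x > 0.f ? x : 0.3f * x; // placeholder slope
    else if constexpr(A == Sigmoid)
        return 1.f / (1.f + std::exp(-x));
    else
        return x; // None
}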
composable_kernel/include/utility/data_type.hpp  View file @ a037693f
...
@@ -1081,11 +1081,11 @@ struct NumericLimits<half_t>
     static constexpr unsigned short binary_max    = 0x7BFF;
     static constexpr unsigned short binary_lowest = 0xFBFF;

-    __host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
+    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }

-    __host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
+    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }

-    __host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
+    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
 };

 } // namespace ck
...
composable_kernel/include/utility/dynamic_buffer.hpp  View file @ a037693f
...
@@ -83,12 +83,28 @@ struct DynamicBuffer
     {
         if constexpr(InvalidElementUseNumericalZeroValue)
         {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
             X tmp;

             __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

             return is_valid_element ? tmp : X{0};
 #else
             return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
 #endif
         }
         else
         {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
             X tmp;

             __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

             return is_valid_element ? tmp : X{invalid_element_value_};
 #else
             return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{invalid_element_value_};
 #endif
         }
     }
...
@@ -117,7 +133,13 @@ struct DynamicBuffer
 #else
         if(is_valid_element)
         {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
             X tmp = x;

             __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
 #else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
 #endif
         }
 #endif
     }
...
@@ -126,7 +148,13 @@ struct DynamicBuffer
         if(is_valid_element)
         {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
             X tmp = x;

             __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
 #else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
 #endif
 #else
             // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
             // inefficient
...
@@ -201,7 +229,13 @@ struct DynamicBuffer
         }
         else
         {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
             X tmp = x;

             __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
 #else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
 #endif
         }
 #endif
     }
...
@@ -210,7 +244,13 @@ struct DynamicBuffer
     {
         if(is_valid_element)
         {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
             X tmp = x;

             __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
 #else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
 #endif
         }
     }
 }
...
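The guarded paths above trade the c_style_pointer_cast for __builtin_memcpy, which avoids a type-punning pointer cast while a modern compiler still folds the copy into a single vector load/store. A standalone sketch of the same pattern is shown below (illustrative only; X stands for any trivially copyable vector type, not a CK type).

#include <cstring>

// Load a wide element X from an array of scalars T without a type-punning pointer cast;
// the compiler is free to lower the memcpy to one vector load.
template <typename X, typename T>
X load_vector(const T* p, int i)
{
    X tmp;
    std::memcpy(&tmp, &p[i], sizeof(X));
    return tmp;
}

// Store a wide element X back into the scalar array the same way.
template <typename X, typename T>
void store_vector(T* p, int i, const X& x)
{
    std::memcpy(&p[i], &x, sizeof(X));
}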
composable_kernel/include/utility/inner_product.hpp  View file @ a037693f
...
@@ -144,9 +144,9 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
                  v_dot4_i32_i8 %0, %1, %2, %0 \n \
                  "
                  : "=v"(c)
-                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
+                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
 #else
-    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
+    c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
 #endif
 #else
     const vector_type<int8_t, 4> a_vector{a};
...
composable_kernel/include/utility/magic_division.hpp  View file @ a037693f
...
@@ -125,7 +125,7 @@ struct MagicDivision
     __host__ __device__ static constexpr int32_t
     DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
     {
-        uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
+        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
         uint32_t tmp          = __umulhi(dividend_u32, multiplier);
         return (tmp + dividend_u32) >> shift;
     }
...
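DoMagicDivision above implements the usual magic-number trick: divide by multiplying with a precomputed reciprocal and shifting the high 32 bits, i.e. q = (umulhi(n, m) + n) >> s. The host-side check below derives m and s with the textbook "round-up" construction and verifies the identity over a range of dividends; it is illustrative only and is not the CK code that precomputes the magic numbers.

#include <cassert>
#include <cstdint>

int main()
{
    for(uint32_t d = 1; d <= 1000; ++d)
    {
        uint32_t s = 0;
        while((uint64_t{1} << s) < d)
            ++s; // s = ceil(log2(d))

        const uint64_t ceil_q = ((uint64_t{1} << (32 + s)) + d - 1) / d;        // ceil(2^(32+s) / d)
        const uint32_t m      = static_cast<uint32_t>(ceil_q - (uint64_t{1} << 32));

        for(uint32_t n = 0; n < 100000; n += 37)
        {
            const uint32_t hi = static_cast<uint32_t>((uint64_t{m} * n) >> 32); // __umulhi
            const uint32_t q  = (hi + n) >> s;
            assert(q == n / d);
        }
    }
    return 0;
}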
composable_kernel/include/utility/statically_indexed_array.hpp  View file @ a037693f
...
@@ -54,5 +54,49 @@ __host__ __device__ constexpr auto make_statically_indexed_array()
     return StaticallyIndexedArray<X, 0>();
 }

 template <typename T, index_t N>
 struct StaticallyIndexedArray_v2
 {
     __host__ __device__ constexpr StaticallyIndexedArray_v2() = default;

     __host__ __device__ static constexpr index_t Size() { return N; }

     // read access
     template <index_t I>
     __host__ __device__ constexpr const auto& At(Number<I>) const
     {
         static_assert(I < N, "wrong! out of range");
         return data_[I];
     }

     // write access
     template <index_t I>
     __host__ __device__ constexpr auto& At(Number<I>)
     {
         static_assert(I < N, "wrong! out of range");
         return data_[I];
     }

     // read access
     template <index_t I>
     __host__ __device__ constexpr const auto& operator[](Number<I> i) const
     {
         return At(i);
     }

     // write access
     template <index_t I>
     __host__ __device__ constexpr auto& operator()(Number<I> i)
     {
         return At(i);
     }

     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

     T data_[N];
 };

 } // namespace ck
 #endif
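StaticallyIndexedArray_v2 is a plain aggregate indexed with Number<I>, so every access is resolved at compile time: operator[] is the read accessor and operator() is the write accessor. A small usage sketch, assuming the usual ck::Number definition and this repository's include paths:

#include "statically_indexed_array.hpp" // include path is an assumption; adjust to your build

using ck::Number;

// Fill and read back a StaticallyIndexedArray_v2 with compile-time indices.
ck::StaticallyIndexedArray_v2<float, 4> make_ramp()
{
    ck::StaticallyIndexedArray_v2<float, 4> a{};
    a(Number<0>{}) = 0.f; // operator() writes
    a(Number<1>{}) = 1.f;
    a(Number<2>{}) = 2.f;
    a(Number<3>{}) = 3.f;

    float sum = a[Number<0>{}] + a[Number<3>{}]; // operator[] reads
    (void)sum;
    return a;
}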
composable_kernel/include/utility/type.hpp  View file @ a037693f
...
@@ -32,8 +32,15 @@ template <typename T>
 inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

 template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
-__host__ __device__ constexpr Y as_type(X x)
+__host__ __device__ constexpr Y bit_cast(const X& x)
 {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
     Y y;

     __builtin_memcpy(&y, &x, sizeof(X));

     return y;
 #else
     union AsType
     {
         X x;
...
@@ -41,6 +48,7 @@ __host__ __device__ constexpr Y as_type(X x)
     };

     return AsType{x}.y;
 #endif
 }

 } // namespace ck
...
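The renamed bit_cast keeps the sizeof(X) == sizeof(Y) constraint and, with CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST enabled, reinterprets the object representation through __builtin_memcpy rather than a union — the same idea std::bit_cast is built on. A small self-contained equivalent for illustration (not the CK implementation itself):

#include <cstdint>
#include <cstring>

// Standalone equivalent of the memcpy-based bit_cast path above.
template <typename Y, typename X>
Y bit_cast_ref(const X& x)
{
    static_assert(sizeof(X) == sizeof(Y), "bit_cast requires equal sizes");
    Y y;
    std::memcpy(&y, &x, sizeof(X));
    return y;
}

// e.g. read the IEEE-754 bit pattern of a float
inline std::uint32_t float_bits(float f) { return bit_cast_ref<std::uint32_t>(f); }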
example/1_gemm_xdl/gemm_xdl.cpp  View file @ a037693f
...
@@ -9,7 +9,6 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
-#include "gemm_common.hpp"
 #include "host_gemm.hpp"
 #include "device_tensor.hpp"
 #include "device_base.hpp"
...
@@ -139,12 +138,12 @@ int main(int argc, char* argv[])
 {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
 }

 DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
...
host/driver_offline/CMakeLists.txt  View file @ a037693f
...
@@ -13,16 +13,25 @@ include_directories(BEFORE
 )

 set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
 set(CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_fwd_driver_offline_nchwc.cpp)
 set(CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_add_fwd_driver_offline_nchwc.cpp)
 set(CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_maxpool_fwd_driver_offline_nchwc.cpp)
 set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
 set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
 set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)

 add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
 add_executable(conv_fwd_driver_offline_nchwc ${CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
 add_executable(conv_add_fwd_driver_offline_nchwc ${CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
 add_executable(conv_maxpool_fwd_driver_offline_nchwc ${CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
 add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
 add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
 add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})

 target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
 target_link_libraries(conv_fwd_driver_offline_nchwc PRIVATE host_tensor)
 target_link_libraries(conv_add_fwd_driver_offline_nchwc PRIVATE host_tensor)
 target_link_libraries(conv_maxpool_fwd_driver_offline_nchwc PRIVATE host_tensor)
 target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
 target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
 target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  0 → 100644  View file @ a037693f
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          ck::ActivTypeEnum_t activ_type,
          typename InLengths,
          typename WeiLengths,
          typename AddLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
    const InLengths& in_n_c0_hi_wi_c1_lengths,
    const WeiLengths& wei_k_c0_y_x_c1_lengths,
    const AddLengths& add_n_k0_hox2_wox2_k1_lengths,
    const OutLengths& out_n_k0_ho_wo_k1_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c0_hi_wi_c1,
    const Tensor<TInWei>& wei_k_c0_y_x_c1,
    const Tensor<TOut>& bias_k0_k1,
    const Tensor<TOut>& add_n_k0_hox2_wox2_k1,
    Tensor<TOut>& add_n_k0_hox2_wox2_k1_out,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    const auto N  = out_n_k0_ho_wo_k1_lengths[I0];
    const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
    const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
    const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
    const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];

    const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
    const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
    const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
    const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];

    const auto K = wei_k_c0_y_x_c1_lengths[I0];
    const auto Y = wei_k_c0_y_x_c1_lengths[I2];
    const auto X = wei_k_c0_y_x_c1_lengths[I3];

    const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2];
    const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3];

    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) *
                                         wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
    DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) *
                                               add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
    bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
    add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());

    constexpr index_t InWeiVectorSize = 8;

    if(C1 % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
    }

#if 0
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock  = 32;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 64;

    constexpr index_t E1         = C0 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t E1PerBlock = C0;

    constexpr index_t KPerThread  = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
    constexpr auto BlockSize = 64;

    constexpr auto KPerBlock  = 8;
    constexpr auto HoPerBlock = 8;
    constexpr auto WoPerBlock = 32;

    constexpr auto E1         = 2 * 9;
    constexpr auto E2         = 1;
    constexpr auto K2         = 2;
    constexpr auto E1PerBlock = 2;

    constexpr auto KPerThread  = KPerBlock;
    constexpr auto HoPerThread = 2;
    constexpr auto WoPerThread = 2;
    constexpr auto EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, 1, KPerBlock, 1>;

    constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr auto ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr auto CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
#endif

    const auto in_n_c0_hi_wi_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
    const auto wei_k_c0_y_x_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
    const auto add_n_k0_hox2_wox2_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

    constexpr auto conv_driver =
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            BThreadTransferSrcScalarPerVector_E2,
            CThreadTransferDstScalarPerVector_K,
            activ_type>{};

    std::cerr << "conv_bias_activ_resize_add_input_"
              << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1
              << "_filter_k" << K << "c" << C0 << "y" << Y << "x" << X << "c" << C1
              << "_addout_n" << N << "k" << K0 << "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1
              << std::endl;

    for(int i = 0; i < 5; i++)
    {
        const auto ave_time =
            conv_driver.Run(wei_k_c0_y_x_c1_desc,
                            in_n_c0_hi_wi_c1_desc,
                            out_n_k0_ho_wo_k1_desc,
                            add_n_k0_hox2_wox2_k1_desc,
                            conv_strides,
                            conv_dilations,
                            in_left_pads,
                            in_right_pads,
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
                            nrepeat);

        {
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());

    conv_driver.Run(wei_k_c0_y_x_c1_desc,
                    in_n_c0_hi_wi_c1_desc,
                    out_n_k0_ho_wo_k1_desc,
                    add_n_k0_hox2_wox2_k1_desc,
                    conv_strides,
                    conv_dilations,
                    in_left_pads,
                    in_right_pads,
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
                    0);

    add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data());
}
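Note: the TFlop/s figure printed by the timing loop above is simply the convolution MAC count, 2·N·K·Ho·Wo·C0·C1·Y·X, divided by the averaged kernel time in milliseconds. A minimal standalone sketch of that arithmetic follows; the helper name and signature are illustrative and not part of this commit.

// Sketch only: reproduces the perf arithmetic used in the timing loop above.
// 2*N*K*Ho*Wo*C*Y*X flops (C carried as C0*C1); GFlop per millisecond equals TFlop/s.
#include <cstddef>

inline float conv_fwd_tflops(std::size_t N, std::size_t K, std::size_t Ho, std::size_t Wo,
                             std::size_t C0, std::size_t C1, std::size_t Y, std::size_t X,
                             float ave_time_ms)
{
    const float flop =
        static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X);
    return flop / (std::size_t(1000) * 1000 * 1000) / ave_time_ms;
}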
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
0 → 100644
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          ck::ActivTypeEnum_t activ_type,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
    const InLengths& in_n_c0_hi_wi_c1_lengths,
    const WeiLengths& wei_k_c0_y_x_c1_lengths,
    const OutLengths& out_n_k0_ho_wo_k1_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c0_hi_wi_c1,
    const Tensor<TInWei>& wei_k_c0_y_x_c1,
    const Tensor<TOut>& bias_k0_k1,
    Tensor<TOut>& out_n_k0_ho_wo_k1,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    const auto N  = out_n_k0_ho_wo_k1_lengths[I0];
    const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
    const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
    const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
    const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];

    const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
    const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
    const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
    const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];

    const auto K = wei_k_c0_y_x_c1_lengths[I0];
    const auto Y = wei_k_c0_y_x_c1_lengths[I2];
    const auto X = wei_k_c0_y_x_c1_lengths[I3];

    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) *
                                         wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
    DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
                                           out_n_k0_ho_wo_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
    bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());

    constexpr index_t InWeiVectorSize = 8;

    if(C1 % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
    }

#if 0
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock  = 32;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 64;

    constexpr index_t E1         = C0 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t E1PerBlock = C0;

    constexpr index_t KPerThread  = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock  = 8;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 32;

    constexpr index_t E1         = 2 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t K2         = 2;
    constexpr index_t E1PerBlock = 2;

    constexpr index_t KPerThread  = KPerBlock;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, 1, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
#endif

    if(KPerThread % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! KPerThread cannot be divided by InWeiVectorSize");
    }

    const auto in_n_c0_hi_wi_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
    const auto wei_k_c0_y_x_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

    constexpr auto conv_driver =
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            BThreadTransferSrcScalarPerVector_E2,
            CThreadTransferDstScalarPerVector_K,
            activ_type>{};

    std::cerr << "conv_bias_activ_input_"
              << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1
              << "_filter_k" << K << "c" << C0 << "y" << Y << "x" << X << "c" << C1
              << "_convout_n" << N << "k" << K0 << "h" << Ho << "w" << Wo << "k" << K1
              << std::endl;

    for(int i = 0; i < 5; i++)
    {
        const auto ave_time =
            conv_driver.Run(wei_k_c0_y_x_c1_desc,
                            in_n_c0_hi_wi_c1_desc,
                            out_n_k0_ho_wo_k1_desc,
                            conv_strides,
                            conv_dilations,
                            in_left_pads,
                            in_right_pads,
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
                            nrepeat);

        {
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
}
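Note: unlike the deleted NCHW wrapper further down, these nc0hwc1 wrappers expect the caller to supply the split layouts (N, C0, Hi, Wi, C1) and (N, K0, Ho, Wo, K1) directly, so the caller has to derive K0/K1 and the output spatial sizes itself. A minimal sketch of that bookkeeping, assuming the standard forward-convolution output-size formula; all names below are illustrative and not part of this commit.

// Sketch only: shape bookkeeping a caller would do before invoking the wrapper above.
// The wrapper fixes InWeiVectorSize = 8, so C and K must both be multiples of 8.
#include <cstdint>

struct SplitOutShape
{
    std::int64_t K0, K1, Ho, Wo;
};

inline SplitOutShape derive_split_out_shape(std::int64_t K, std::int64_t Hi, std::int64_t Wi,
                                            std::int64_t Y, std::int64_t X,
                                            std::int64_t stride_h, std::int64_t stride_w,
                                            std::int64_t dilation_h, std::int64_t dilation_w,
                                            std::int64_t pad_h_left, std::int64_t pad_h_right,
                                            std::int64_t pad_w_left, std::int64_t pad_w_right,
                                            std::int64_t vector_size = 8)
{
    // Standard forward-convolution output size with padding and dilation.
    const std::int64_t Ho =
        (Hi + pad_h_left + pad_h_right - dilation_h * (Y - 1) - 1) / stride_h + 1;
    const std::int64_t Wo =
        (Wi + pad_w_left + pad_w_right - dilation_w * (X - 1) - 1) / stride_w + 1;

    return {K / vector_size, vector_size, Ho, Wo};
}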
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
deleted 100644 → 0
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"

template <typename TInWei,
          ck::index_t InWeiVectorSize,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t /* nrepeat */)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = out_n_k_ho_wo_lengths[I0];
    const auto K = out_n_k_ho_wo_lengths[I1];
    const auto C = wei_k_c_y_x_lengths[I1];

    const auto Hi = in_n_c_hi_wi_lengths[I2];
    const auto Wi = in_n_c_hi_wi_lengths[I3];

    const auto Ho = out_n_k_ho_wo_lengths[I2];
    const auto Wo = out_n_k_ho_wo_lengths[I3];

    const auto Y = wei_k_c_y_x_lengths[I2];
    const auto X = wei_k_c_y_x_lengths[I3];

    const auto C0 = C / Number<InWeiVectorSize>{};
    const auto C1 = Number<InWeiVectorSize>{};

    const auto K0 = K / Number<InWeiVectorSize>{};
    const auto K1 = Number<InWeiVectorSize>{};

    Tensor<TInWei> in_n_c0_hi_wi_c1(
        HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
    Tensor<TInWei> wei_k_c0_y_x_c1(
        HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
    Tensor<TOut> out_n_k0_ho_wo_k1(
        HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));

    auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
        in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
            in_n_c_hi_wi(n, c, hi, wi);
    };

    auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
        wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
            wei_k_c_y_x(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
    make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();

    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) *
                                         wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
                                           out_n_k0_ho_wo_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());

    const auto in_n_c0_hi_wi_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
    const auto wei_k_c0_y_x_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

#if 1
    // cdata = 64, BlockSize = 64, 16x8x32x4
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock  = 16;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t EPerBlock  = 1;

    constexpr index_t KPerThread  = KPerBlock;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = EPerBlock;

    using ABlockTransferThreadSliceLengths_E_K   = Sequence<3, 1>;
    using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
    constexpr index_t ABlockTransferDstScalarPerVector_K = 1;

    constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;

    constexpr index_t CThreadTransferDstScalarPerVector_W = 16;

    static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#else
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock  = 16;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t EPerBlock  = 1;

    constexpr index_t KPerThread  = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = EPerBlock;

    using ABlockTransferThreadSliceLengths_E_K   = Sequence<9, 1>;
    using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
    constexpr index_t ABlockTransferDstScalarPerVector_K = 1;

    constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;

    constexpr index_t CThreadTransferDstScalarPerVector_W = K1;

    static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#endif

    constexpr auto conv_driver =
#if 0
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
#else
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
#endif
        <BlockSize,
         typename vector_type<TInWei, InWeiVectorSize>::type,
         TAcc,
         TOut,
         KPerBlock,
         HoPerBlock,
         WoPerBlock,
         EPerBlock,
         KPerThread,
         HoPerThread,
         WoPerThread,
         EPerThread,
         ABlockTransferThreadSliceLengths_E_K,
         ABlockTransferThreadClusterLengths_E_K,
         ABlockTransferSrcScalarPerVector_E,
         ABlockTransferDstScalarPerVector_K,
         BThreadTransferSrcScalarPerVector_W,
         CThreadTransferDstScalarPerVector_W>{};

    conv_driver.Run(wei_k_c0_y_x_desc,
                    in_n_c0_hi_wi_desc,
                    out_n_k0_ho_wo_k1_desc,
                    conv_strides,
                    conv_dilations,
                    in_left_pads,
                    in_right_pads,
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));

    out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());

    auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
        out_n_k_ho_wo(n, k, ho, wo) =
            out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
    };

    make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
}
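For reference, the layout conversion this deleted wrapper performed with make_ParallelTensorFunctor is just the channel split c → (c / InWeiVectorSize, c % InWeiVectorSize). A minimal serial sketch of the same NCHW → NC0HWC1 packing on dense buffers follows; the buffer and function names are illustrative and not taken from the repository.

// Sketch only: serial equivalent of f_nchw2nc0hwc1 from the deleted wrapper above.
// src is a dense NCHW buffer; dst must hold N * (C / vector_size) * H * W * vector_size
// elements laid out as N x C0 x H x W x C1 with C1 == vector_size.
#include <cstddef>
#include <vector>

template <typename T>
void pack_nchw_to_nc0hwc1(const std::vector<T>& src, std::vector<T>& dst,
                          std::size_t N, std::size_t C, std::size_t H, std::size_t W,
                          std::size_t vector_size)
{
    const std::size_t C0 = C / vector_size;

    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    const std::size_t c0 = c / vector_size;
                    const std::size_t c1 = c % vector_size;

                    dst[(((n * C0 + c0) * H + h) * W + w) * vector_size + c1] =
                        src[((n * C + c) * H + h) * W + w];
                }
}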