Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d99e020d
Commit
d99e020d
authored
May 31, 2021
by
Chao Liu
Browse files
overhauling fwd-v4r4
parent
4b21c0fd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
195 additions
and
156 deletions
+195
-156
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...m_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+1
-35
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
...e_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
+1
-1
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
...l/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
+144
-101
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...nvolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+49
-19
No files found.
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
d99e020d
...
@@ -121,44 +121,10 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
...
@@ -121,44 +121,10 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
constexpr
auto
wei_gemmk_gemmm_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
constexpr
auto
in_gemmk_gemmn_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
in_gemmk_gemmn_global_desc
,
in_gemmk_gemmn_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemm_block_cluster_desc
,
out_gemm_block_cluster_desc
);
wei_gemmk_gemmm_global_iterator_hacks
,
in_gemmk_gemmn_global_iterator_hacks
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
View file @
d99e020d
...
@@ -74,7 +74,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
...
@@ -74,7 +74,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeBKN0N1BlockDescriptor
(
const
BKNBlockDesc
&
n
_k_n_block_desc
)
MakeBKN0N1BlockDescriptor
(
const
BKNBlockDesc
&
b
_k_n_block_desc
)
{
{
const
auto
b_k_n0_n1_block_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
b_k_n0_n1_block_desc
=
transform_dynamic_tensor_descriptor
(
BKNBlockDesc
{},
BKNBlockDesc
{},
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
View file @
d99e020d
...
@@ -17,9 +17,9 @@ template <typename GridwiseGemm,
...
@@ -17,9 +17,9 @@ template <typename GridwiseGemm,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
A
Global
Desc
,
typename
A
KMGrid
Desc
,
typename
B
Global
Desc
,
typename
B
KNGrid
Desc
,
typename
C
Global
Desc
,
typename
C
M0M1N0N1Grid
Desc
,
typename
CBlockClusterDesc
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
...
@@ -27,20 +27,20 @@ __global__ void
...
@@ -27,20 +27,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_dynamic_gemm_v1r2
(
const
FloatA
*
__restrict__
p_a_g
lobal
,
kernel_dynamic_gemm_v1r2
(
const
FloatA
*
__restrict__
p_a_g
rid
,
const
FloatB
*
__restrict__
p_b_g
lobal
,
const
FloatB
*
__restrict__
p_b_g
rid
,
FloatC
*
__restrict__
p_c_g
lobal
,
FloatC
*
__restrict__
p_c_g
rid
,
const
A
Global
Desc
a_k_m_g
lobal
_desc
,
const
A
KMGrid
Desc
a_k_m_g
rid
_desc
,
const
B
Global
Desc
b_k_n_g
lobal
_desc
,
const
B
KNGrid
Desc
b_k_n_g
rid
_desc
,
const
C
Global
Desc
c_m0_m1_n0_n1_g
lobal
_desc
,
const
C
M0M1N0N1Grid
Desc
c_m0_m1_n0_n1_g
rid
_desc
,
const
CBlockClusterDesc
c_block_cluster_desc
)
const
CBlockClusterDesc
c_block_cluster_desc
)
{
{
GridwiseGemm
::
Run
(
p_a_g
lobal
,
GridwiseGemm
::
Run
(
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
,
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
@@ -53,9 +53,9 @@ template <typename GridwiseGemm,
...
@@ -53,9 +53,9 @@ template <typename GridwiseGemm,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
A
Global
Desc
,
typename
A
KMGrid
Desc
,
typename
B
Global
Desc
,
typename
B
KNGrid
Desc
,
typename
C
Global
Desc
,
typename
C
M0M1N0N1Grid
Desc
,
typename
CBlockClusterDesc
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
...
@@ -63,33 +63,33 @@ __global__ void
...
@@ -63,33 +63,33 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_dynamic_gemm_v1r2
(
const
FloatA
*
__restrict__
p_a_g
lobal
,
kernel_dynamic_gemm_v1r2
(
const
FloatA
*
__restrict__
p_a_g
rid
,
const
FloatB
*
__restrict__
p_b_g
lobal
,
const
FloatB
*
__restrict__
p_b_g
rid
,
FloatC
*
__restrict__
p_c_g
lobal
,
FloatC
*
__restrict__
p_c_g
rid
,
const
void
__CONSTANT__
*
p_a_k_m_g
lobal
_desc
,
const
void
__CONSTANT__
*
p_a_k_m_g
rid
_desc
,
const
void
__CONSTANT__
*
p_b_k_n_g
lobal
_desc
,
const
void
__CONSTANT__
*
p_b_k_n_g
rid
_desc
,
const
void
__CONSTANT__
*
p_c_m0_m1_n0_n1_g
lobal
_desc
,
const
void
__CONSTANT__
*
p_c_m0_m1_n0_n1_g
rid
_desc
,
const
void
__CONSTANT__
*
p_c_block_cluster_desc
)
const
void
__CONSTANT__
*
p_c_block_cluster_desc
)
{
{
// first cast void __CONSTANT__ void* to void*
// first cast void __CONSTANT__ void* to void*
// second cast void* to Desc*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
// the copy constructor of tensor descriptor doesn't take address_space(4)
const
auto
a_k_m_g
lobal
_desc
=
const
auto
a_k_m_g
rid
_desc
=
*
reinterpret_cast
<
const
A
Global
Desc
*>
((
const
void
*
)
p_a_k_m_g
lobal
_desc
);
*
reinterpret_cast
<
const
A
KMGrid
Desc
*>
((
const
void
*
)
p_a_k_m_g
rid
_desc
);
const
auto
b_k_n_g
lobal
_desc
=
const
auto
b_k_n_g
rid
_desc
=
*
reinterpret_cast
<
const
B
Global
Desc
*>
((
const
void
*
)
p_b_k_n_g
lobal
_desc
);
*
reinterpret_cast
<
const
B
KNGrid
Desc
*>
((
const
void
*
)
p_b_k_n_g
rid
_desc
);
const
auto
c_m0_m1_n0_n1_g
lobal
_desc
=
const
auto
c_m0_m1_n0_n1_g
rid
_desc
=
*
reinterpret_cast
<
const
C
Global
Desc
*>
((
const
void
*
)
p_c_m0_m1_n0_n1_g
lobal
_desc
);
*
reinterpret_cast
<
const
C
M0M1N0N1Grid
Desc
*>
((
const
void
*
)
p_c_m0_m1_n0_n1_g
rid
_desc
);
const
auto
c_block_cluster_desc
=
const
auto
c_block_cluster_desc
=
*
reinterpret_cast
<
const
CBlockClusterDesc
*>
((
const
void
*
)
p_c_block_cluster_desc
);
*
reinterpret_cast
<
const
CBlockClusterDesc
*>
((
const
void
*
)
p_c_block_cluster_desc
);
GridwiseGemm
::
Run
(
p_a_g
lobal
,
GridwiseGemm
::
Run
(
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
,
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
@@ -101,9 +101,9 @@ template <index_t BlockSize,
...
@@ -101,9 +101,9 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
typename
A
Global
Desc
,
typename
A
KMGrid
Desc
,
typename
B
Global
Desc
,
typename
B
KNGrid
Desc
,
typename
C
Global
Desc
,
typename
C
M0M1N0N1Grid
Desc
,
typename
CBlockClusterDesc
,
typename
CBlockClusterDesc
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
...
@@ -134,13 +134,18 @@ template <index_t BlockSize,
...
@@ -134,13 +134,18 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AG
lobal
IteratorHacks
,
typename
AG
rid
IteratorHacks
,
typename
BG
lobal
IteratorHacks
,
typename
BG
rid
IteratorHacks
,
typename
CG
lobal
IteratorHacks
,
typename
CG
rid
IteratorHacks
,
typename
AG
lobal
MoveSliceWindowIteratorHacks
,
typename
AG
rid
MoveSliceWindowIteratorHacks
,
typename
BG
lobal
MoveSliceWindowIteratorHacks
>
typename
BG
rid
MoveSliceWindowIteratorHacks
>
struct
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
struct
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
...
@@ -168,33 +173,71 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -168,33 +173,71 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
}
}
// Re-express the 2-D A grid descriptor [K, M] as a 3-D descriptor [K, M0, M1].
//
// Dimension 0 (K) is passed through unchanged. Dimension 1 (M) is unmerged
// into (M0, M1) where M1 = MPerBlock and M0 = M / M1.
//
// NOTE(review): M0 is computed with a plain division, so this presumably
// expects M to be a multiple of MPerBlock -- confirm against the callers.
//
// @param a_k_m_grid_desc  descriptor whose lengths are read as (K, M)
// @return                 the transformed [K, M0, M1] descriptor
__host__ __device__ static constexpr auto
MakeAKM0M1BlockClusterizedGridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
    const auto k_len = a_k_m_grid_desc.GetLength(I0);
    const auto m_len = a_k_m_grid_desc.GetLength(I1);

    // Inner length is the compile-time per-block tile size; outer length is
    // how many such tiles fit in M.
    const auto m_inner = Number<MPerBlock>{};
    const auto m_outer = m_len / m_inner;

    return transform_dynamic_tensor_descriptor(
        a_k_m_grid_desc,
        make_tuple(make_pass_through_transform(k_len),
                   make_unmerge_transform(make_tuple(m_outer, m_inner))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),      // source dims: K, M
        make_tuple(Sequence<0>{}, Sequence<1, 2>{})); // result dims: K, (M0, M1)
}
// Re-express the 2-D B grid descriptor [K, N] as a 3-D descriptor [K, N0, N1].
//
// Dimension 0 (K) is passed through unchanged. Dimension 1 (N) is unmerged
// into (N0, N1) where N1 = NPerBlock and N0 = N / N1.
//
// NOTE(review): N0 is computed with a plain division, so this presumably
// expects N to be a multiple of NPerBlock -- confirm against the callers.
//
// @param b_k_n_grid_desc  descriptor whose lengths are read as (K, N)
// @return                 the transformed [K, N0, N1] descriptor
__host__ __device__ static constexpr auto
MakeBKN0N1BlockClusterizedGridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
    const auto k_len = b_k_n_grid_desc.GetLength(I0);
    const auto n_len = b_k_n_grid_desc.GetLength(I1);

    // Inner length is the compile-time per-block tile size; outer length is
    // how many such tiles fit in N.
    const auto n_inner = Number<NPerBlock>{};
    const auto n_outer = n_len / n_inner;

    return transform_dynamic_tensor_descriptor(
        b_k_n_grid_desc,
        make_tuple(make_pass_through_transform(k_len),
                   make_unmerge_transform(make_tuple(n_outer, n_inner))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),      // source dims: K, N
        make_tuple(Sequence<0>{}, Sequence<1, 2>{})); // result dims: K, (N0, N1)
}
// Descriptor type produced by MakeAKM0M1BlockClusterizedGridDescriptor when
// given a default-constructed AKMGridDesc.
using AKM0M1GridDesc = decltype(MakeAKM0M1BlockClusterizedGridDescriptor(AKMGridDesc{}));
// Descriptor type produced by MakeBKN0N1BlockClusterizedGridDescriptor when
// given a default-constructed BKNGridDesc.
using BKN0N1GridDesc = decltype(MakeBKN0N1BlockClusterizedGridDescriptor(BKNGridDesc{}));
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_g
lobal
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_g
rid
,
const
FloatAB
*
__restrict__
p_b_g
lobal
,
const
FloatAB
*
__restrict__
p_b_g
rid
,
FloatC
*
__restrict__
p_c_g
lobal
,
FloatC
*
__restrict__
p_c_g
rid
,
const
A
Global
Desc
&
a_k_m_g
lobal
_desc
,
const
A
KMGrid
Desc
&
a_k_m_g
rid
_desc
,
const
B
Global
Desc
&
b_k_n_g
lobal
_desc
,
const
B
KNGrid
Desc
&
b_k_n_g
rid
_desc
,
const
C
Global
Desc
&
c_m0_m1_n0_n1_g
lobal
_desc
,
const
C
M0M1N0N1Grid
Desc
&
c_m0_m1_n0_n1_g
rid
_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
#if 0
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
#endif
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpace
::
Global
>
(
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpace
::
Global
>
(
p_a_g
lobal
,
a_k_m_g
lobal
_desc
.
GetElementSpaceSize
());
p_a_g
rid
,
a_k_m_g
rid
_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpace
::
Global
>
(
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpace
::
Global
>
(
p_b_g
lobal
,
b_k_n_g
lobal
_desc
.
GetElementSpaceSize
());
p_b_g
rid
,
b_k_n_g
rid
_desc
.
GetElementSpaceSize
());
auto
c_g
lobal
_buf
=
make_dynamic_buffer
<
AddressSpace
::
Global
>
(
auto
c_g
rid
_buf
=
make_dynamic_buffer
<
AddressSpace
::
Global
>
(
p_c_g
lobal
,
c_m0_m1_n0_n1_g
lobal
_desc
.
GetElementSpaceSize
());
p_c_g
rid
,
c_m0_m1_n0_n1_g
rid
_desc
.
GetElementSpaceSize
());
const
auto
K
=
a_k_m_g
lobal
_desc
.
GetLength
(
I0
);
const
auto
K
=
a_k_m_g
rid
_desc
.
GetLength
(
I0
);
const
auto
M
=
a_k_m_g
lobal
_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k_m_g
rid
_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_g
lobal
_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_g
rid
_desc
.
GetLength
(
I1
);
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -233,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -233,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
a_k_m_g
lobal
_desc
),
decltype
(
a_k_m_g
rid
_desc
),
decltype
(
a_k_m_block_desc
),
decltype
(
a_k_m_block_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
...
@@ -245,7 +288,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -245,7 +288,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_global
),
make_multi_index
(
0
,
m_block_data_idx_on_global
),
a_k_m_block_desc
,
a_k_m_block_desc
,
make_multi_index
(
0
,
0
));
make_multi_index
(
0
,
0
));
...
@@ -260,7 +303,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -260,7 +303,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
b_k_n_g
lobal
_desc
),
decltype
(
b_k_n_g
rid
_desc
),
decltype
(
b_k_n_block_desc
),
decltype
(
b_k_n_block_desc
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
...
@@ -272,7 +315,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -272,7 +315,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_global
),
make_multi_index
(
0
,
n_block_data_idx_on_global
),
b_k_n_block_desc
,
b_k_n_block_desc
,
make_multi_index
(
0
,
0
));
make_multi_index
(
0
,
0
));
...
@@ -328,15 +371,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -328,15 +371,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k_m_global_iterator_hacks
=
AG
lobal
IteratorHacks
{};
constexpr
auto
a_k_m_global_iterator_hacks
=
AG
rid
IteratorHacks
{};
constexpr
auto
b_k_n_global_iterator_hacks
=
BG
lobal
IteratorHacks
{};
constexpr
auto
b_k_n_global_iterator_hacks
=
BG
rid
IteratorHacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// threadwise copy
constexpr
auto
a_k_m_global_move_slice_window_iterator_hack
=
constexpr
auto
a_k_m_global_move_slice_window_iterator_hack
=
AG
lobal
MoveSliceWindowIteratorHacks
{};
AG
rid
MoveSliceWindowIteratorHacks
{};
constexpr
auto
b_k_n_global_move_slice_window_iterator_hack
=
constexpr
auto
b_k_n_global_move_slice_window_iterator_hack
=
BG
lobal
MoveSliceWindowIteratorHacks
{};
BG
rid
MoveSliceWindowIteratorHacks
{};
auto
a_block_even_buf
=
make_dynamic_buffer
<
AddressSpace
::
Lds
>
(
auto
a_block_even_buf
=
make_dynamic_buffer
<
AddressSpace
::
Lds
>
(
p_a_block_double
,
a_k_m_block_desc
.
GetElementSpaceSize
());
p_a_block_double
,
a_k_m_block_desc
.
GetElementSpaceSize
());
...
@@ -350,8 +393,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -350,8 +393,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
a_blockwise_copy
.
RunRead
(
a_k_m_g
lobal
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
a_blockwise_copy
.
RunRead
(
a_k_m_g
rid
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_k_n_g
lobal
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_k_n_g
rid
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
a_block_even_buf
);
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
b_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
b_block_even_buf
);
...
@@ -366,10 +409,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -366,10 +409,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
do
do
{
{
// even iteration
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_g
lobal
_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_g
rid
_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k_m_global_move_slice_window_iterator_hack
);
a_k_m_global_move_slice_window_iterator_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_g
lobal
_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_g
rid
_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
b_k_n_global_move_slice_window_iterator_hack
);
b_k_n_global_move_slice_window_iterator_hack
);
...
@@ -377,9 +420,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -377,9 +420,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: load next data from device mem
// LDS double buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k_m_g
lobal
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
a_k_m_g
rid
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k_n_g
lobal
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
b_k_n_g
rid
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -390,10 +433,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -390,10 +433,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
b_block_odd_buf
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
b_block_odd_buf
);
// odd iteration
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_g
lobal
_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_g
rid
_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k_m_global_move_slice_window_iterator_hack
);
a_k_m_global_move_slice_window_iterator_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_g
lobal
_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_g
rid
_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
b_k_n_global_move_slice_window_iterator_hack
);
b_k_n_global_move_slice_window_iterator_hack
);
...
@@ -401,9 +444,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -401,9 +444,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: load next data from device mem
// LDS double buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k_m_g
lobal
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
a_k_m_g
rid
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k_n_g
lobal
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
b_k_n_g
rid
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -420,18 +463,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -420,18 +463,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_g
lobal
_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_g
rid
_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k_m_global_move_slice_window_iterator_hack
);
a_k_m_global_move_slice_window_iterator_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_g
lobal
_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_g
rid
_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
b_k_n_global_move_slice_window_iterator_hack
);
b_k_n_global_move_slice_window_iterator_hack
);
__syncthreads
();
__syncthreads
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_k_m_g
lobal
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
a_blockwise_copy
.
RunRead
(
a_k_m_g
rid
_desc
,
a_global_buf
,
a_k_m_global_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_k_n_g
lobal
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_k_n_g
rid
_desc
,
b_global_buf
,
b_k_n_global_iterator_hacks
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -462,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -462,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
constexpr
auto
N1
=
Number
<
N1PerThread
*
M1N1ThreadClusterN10
*
M1N1ThreadClusterN11
>
{};
constexpr
auto
N1
=
Number
<
N1PerThread
*
M1N1ThreadClusterN10
*
M1N1ThreadClusterN11
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr
auto
c_m0_m1_n0_n1_global_tensor_iterator_hacks
=
CG
lobal
IteratorHacks
{};
constexpr
auto
c_m0_m1_n0_n1_global_tensor_iterator_hacks
=
CG
rid
IteratorHacks
{};
const
auto
c_thread_data_idx_on_block
=
const
auto
c_thread_data_idx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
get_thread_local_1d_id
());
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
get_thread_local_1d_id
());
...
@@ -470,7 +513,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -470,7 +513,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
FloatAcc
,
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
FloatC
,
decltype
(
c_m0_m1_n0_n1_thread_desc
),
decltype
(
c_m0_m1_n0_n1_thread_desc
),
decltype
(
c_m0_m1_n0_n1_g
lobal
_desc
),
decltype
(
c_m0_m1_n0_n1_g
rid
_desc
),
decltype
(
c_m0_m1_n0_n1_thread_tensor_lengths
),
decltype
(
c_m0_m1_n0_n1_thread_tensor_lengths
),
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
...
@@ -478,7 +521,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -478,7 +521,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
1
,
1
,
true
>
{
true
>
{
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
make_multi_index
(
m_block_data_idx_on_global
/
M1
+
c_thread_data_idx_on_block
[
I0
],
make_multi_index
(
m_block_data_idx_on_global
/
M1
+
c_thread_data_idx_on_block
[
I0
],
c_thread_data_idx_on_block
[
I1
],
c_thread_data_idx_on_block
[
I1
],
n_block_data_idx_on_global
/
N1
+
c_thread_data_idx_on_block
[
I2
],
n_block_data_idx_on_global
/
N1
+
c_thread_data_idx_on_block
[
I2
],
...
@@ -486,19 +529,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -486,19 +529,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
.
Run
(
c_m0_m1_n0_n1_thread_desc
,
.
Run
(
c_m0_m1_n0_n1_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_thread_buf
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_g
lobal
_buf
,
c_g
rid
_buf
,
c_m0_m1_n0_n1_global_tensor_iterator_hacks
);
c_m0_m1_n0_n1_global_tensor_iterator_hacks
);
}
}
}
}
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_g
lobal
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_g
rid
,
const
FloatAB
*
__restrict__
p_b_g
lobal
,
const
FloatAB
*
__restrict__
p_b_g
rid
,
FloatC
*
__restrict__
p_c_g
lobal
,
FloatC
*
__restrict__
p_c_g
rid
,
const
A
Global
Desc
&
a_k_m_g
lobal
_desc
,
const
A
KMGrid
Desc
&
a_k_m_g
rid
_desc
,
const
B
Global
Desc
&
b_k_n_g
lobal
_desc
,
const
B
KNGrid
Desc
&
b_k_n_g
rid
_desc
,
const
C
Global
Desc
&
c_m0_m1_n0_n1_g
lobal
_desc
,
const
C
M0M1N0N1Grid
Desc
&
c_m0_m1_n0_n1_g
rid
_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
...
@@ -507,12 +550,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -507,12 +550,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
Run
(
p_a_g
lobal
,
Run
(
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
,
c_block_cluster_desc
,
p_shared_block
,
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
d99e020d
...
@@ -485,6 +485,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -485,6 +485,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
in_left_pads
,
in_left_pads
,
in_right_pads
);
in_right_pads
);
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
constexpr
auto
wei_gemmk_gemmm_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
constexpr
auto
in_gemmk_gemmn_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
launch_kernel_dynamic_gemm_v1r2
<
float
ave_time
=
launch_kernel_dynamic_gemm_v1r2
<
...
@@ -527,25 +556,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -527,25 +556,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence
<
2
,
3
,
0
,
1
>
,
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
decltype
(
descs
[
I4
]),
decltype
(
wei_gemmk_gemmm_global_iterator_hacks
),
decltype
(
descs
[
I5
]),
decltype
(
in_gemmk_gemmn_global_iterator_hacks
),
decltype
(
descs
[
I6
]),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
),
decltype
(
descs
[
I7
]),
decltype
(
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
),
decltype
(
descs
[
I8
])
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
decltype
(
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
)
>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
descs
[
I1
],
descs
[
I0
],
descs
[
I2
],
descs
[
I1
],
descs
[
I3
],
descs
[
I2
],
descs
[
I4
],
descs
[
I3
],
descs
[
I5
],
wei_gemmk_gemmm_global_iterator_hacks
,
descs
[
I6
],
in_gemmk_gemmn_global_iterator_hacks
,
descs
[
I7
],
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
descs
[
I8
],
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
nrepeat
);
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
,
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment