Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
8bdaba51
"git@developer.sourcefind.cn:yangql/composable_kernel-1.git" did not exist on "10c72aced809dd4c826cab1834d656b5959d4dd5"
Commit
8bdaba51
authored
Aug 13, 2019
by
Chao Liu
Browse files
clean up
parent
fab2f10a
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
510 additions
and
656 deletions
+510
-656
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp
...ridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp
+7
-8
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp
...ridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp
+68
-73
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp
...n_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp
+107
-79
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+32
-35
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
.../gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
+7
-8
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
...ion_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
+24
-29
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+20
-21
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+30
-34
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
+15
-18
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp
+19
-22
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+29
-33
composable_kernel/include/tensor_description/tensor_coordinate.hpp
...e_kernel/include/tensor_description/tensor_coordinate.hpp
+22
-0
composable_kernel/include/tensor_operation/blockwise_2d_tensor_op.hpp
...ernel/include/tensor_operation/blockwise_2d_tensor_op.hpp
+9
-9
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp
...ernel/include/tensor_operation/blockwise_3d_tensor_op.hpp
+5
-5
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp
...ernel/include/tensor_operation/blockwise_4d_tensor_op.hpp
+5
-5
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
.../tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+22
-27
composable_kernel/include/tensor_operation/blockwise_tensor_slice_copy.hpp
.../include/tensor_operation/blockwise_tensor_slice_copy.hpp
+8
-8
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+3
-170
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
...de/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
+73
-67
driver/src/driver.cpp
driver/src/driver.cpp
+5
-5
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp
View file @
8bdaba51
...
@@ -241,16 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
...
@@ -241,16 +241,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block
);
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block
);
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
,
p_wei_block
);
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
,
p_wei_block
);
#else
#else
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegisterBuffer
(
p_in_global_block_offset
,
p_in_register_buffer
);
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterBuffer
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_buffer
);
p_wei_register_clipboard
);
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
p_in_block
);
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
p_wei_block
);
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block
);
#endif
#endif
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp
View file @
8bdaba51
...
@@ -4,11 +4,8 @@
...
@@ -4,11 +4,8 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
#include "blockwise_batched_gemm.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -37,10 +34,13 @@ template <index_t GridSize,
...
@@ -37,10 +34,13 @@ template <index_t GridSize,
index_t
GemmKPerThreadLoop
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
index_t
GemmDataPerReadB
,
class
InBlockCopySubLengths_CHWN
,
class
InBlockCopyClusterLengths_CHWN
,
class
InBlockCopyClusterLengths_CHWN
,
index_t
InBlockCopyDataPerRead_N
,
index_t
InBlockCopyDataPerAccess_N
,
index_t
WeiBlockCopyDataPerRead_K
,
class
WeiBlockCopySubLengths_CK
,
index_t
OutThreadCopyDataPerWrite_N
>
class
WeiBlockCopyClusterLengths_CK
,
index_t
WeiBlockCopyDataPerAccess_K
,
index_t
OutThreadCopyDataPerAccess_N
>
struct
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
struct
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
...
@@ -103,8 +103,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -103,8 +103,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDataPer
Read
_N
,
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDataPer
Access
_N
,
WeiBlockCopyDataPer
Read
_K
,
WeiBlockCopyDataPer
Access
_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
...
@@ -123,24 +123,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -123,24 +123,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor_packed
(
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
// blockwise copy
// blockwise copy
// input: format is [C, Hi, Wi, N]
// input: format is [C, Hi, Wi, N]
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#elif
0
using
InBlockCopySubLengths_CHWN
=
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()
/
InBlockCopyClusterLengths_CHWN
{});
auto
blockwise_in_copy
=
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
Float
,
decltype
(
in_c_h_w_n_global_desc
),
decltype
(
in_c_h_w_n_global_desc
),
decltype
(
in_c_h_w_n_block_desc
),
decltype
(
in_c_h_w_n_block_desc
),
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()),
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()),
...
@@ -149,33 +135,28 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -149,33 +135,28 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
1
,
3
,
1
>
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
});
3
,
#elif 1
InBlockCopyDataPerAccess_N
,
using
InBlockCopySubLengths_CHWN
=
InBlockCopyDataPerAccess_N
>
({
0
,
0
,
0
,
0
},
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()
/
InBlockCopyClusterLengths_CHWN
{});
{
0
,
0
,
0
,
0
});
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
decltype
(
in_c_h_w_n_global_desc
),
decltype
(
in_c_h_w_n_block_desc
),
NormalTensorCoordinate
<
decltype
(
in_c_h_w_n_global_desc
)
>
,
NormalTensorCoordinate
<
decltype
(
in_c_h_w_n_block_desc
)
>
,
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()),
InBlockCopySubLengths_CHWN
,
InBlockCopyClusterLengths_CHWN
,
Sequence
<
0
,
1
,
2
,
3
>>
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
});
#endif
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock,
X *
KPerBlock]
// format is [CPerBlock, KPerBlock]
const
auto
blockwise_wei_copy
=
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
Float
,
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
WeiBlockCopySubLengths_CK
,
WeiBlockCopyDataPerRead_K
>
({
0
,
0
},
{
0
,
0
});
WeiBlockCopyClusterLengths_CK
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
1
,
WeiBlockCopyDataPerAccess_K
,
WeiBlockCopyDataPerAccess_K
>
({
0
,
0
},
{
0
,
0
});
// a series of blockwise batched GEMM
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// C_matrix += transpose(A_matrix) * B_matrix
...
@@ -278,7 +259,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -278,7 +259,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
}
}
}
// output: register to global mem
,
// output: register to global mem
const
auto
c_thread_mtx_begin
=
const
auto
c_thread_mtx_begin
=
blockwise_batch_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
blockwise_batch_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
...
@@ -329,17 +310,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -329,17 +310,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
}
#endif
#endif
threadwise_tensor_slice_copy
(
out_10d_thread_desc
,
Float
*
p_out_thread_on_global
=
p_out_global
+
p_out_thread
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
out_10d_global_desc
,
k_block_data_begin
+
k_thread_data_begin
,
p_out_global
+
ho_block_data_begin
+
ho_thread_data_begin
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
wo_block_data_begin
+
wo_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
);
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
ThreadwiseGenericTensorSliceCopy_v2r1
<
decltype
(
out_10d_thread_desc
),
n_block_data_begin
+
n_thread_data_begin
),
decltype
(
out_10d_global_desc
),
out_10d_thread_desc
.
GetLengths
(),
decltype
(
out_10d_thread_desc
.
GetLengths
()),
Number
<
OutThreadCopyDataPerWrite_N
>
{});
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
9
,
9
,
OutThreadCopyDataPerAccess_N
,
OutThreadCopyDataPerAccess_N
>
(
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
())
.
Run
(
p_out_thread
,
p_out_thread_on_global
);
}).
Else
([
&
](
auto
fwd
)
{
}).
Else
([
&
](
auto
fwd
)
{
static_assert
(
fwd
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
static_assert
(
fwd
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
GemmNPerThreadSubC
%
NPerThread
==
0
,
GemmNPerThreadSubC
%
NPerThread
==
0
,
...
@@ -380,17 +368,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -380,17 +368,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
}
#endif
#endif
threadwise_tensor_slice_copy
(
out_10d_thread_desc
,
Float
*
p_out_thread_on_global
=
p_out_global
+
p_out_thread
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
out_10d_global_desc
,
k_block_data_begin
+
k_thread_data_begin
,
p_out_global
+
ho_block_data_begin
+
ho_thread_data_begin
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
wo_block_data_begin
+
wo_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
);
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
ThreadwiseGenericTensorSliceCopy_v2r1
<
decltype
(
out_10d_thread_desc
),
n_block_data_begin
+
n_thread_data_begin
),
decltype
(
out_10d_global_desc
),
out_10d_thread_desc
.
GetLengths
(),
decltype
(
out_10d_thread_desc
.
GetLengths
()),
Number
<
OutThreadCopyDataPerWrite_N
>
{});
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
9
,
9
,
OutThreadCopyDataPerAccess_N
,
OutThreadCopyDataPerAccess_N
>
(
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
())
.
Run
(
p_out_thread
,
p_out_thread_on_global
);
});
});
}
}
};
};
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -4,10 +4,8 @@
...
@@ -4,10 +4,8 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
#include "blockwise_batched_gemm.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -36,10 +34,13 @@ template <index_t GridSize,
...
@@ -36,10 +34,13 @@ template <index_t GridSize,
index_t
GemmKPerThreadLoop
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
index_t
GemmDataPerReadB
,
class
InBlockCopySubLengths_CHWN
,
class
InBlockCopyClusterLengths_CHWN
,
class
InBlockCopyClusterLengths_CHWN
,
index_t
InBlockCopyDataPerRead_N
,
index_t
InBlockCopyDataPerAccess_N
,
index_t
WeiBlockCopyDataPerRead_K
,
class
WeiBlockCopySubLengths_CK
,
index_t
OutThreadCopyDataPerWrite_N
>
class
WeiBlockCopyClusterLengths_CK
,
index_t
WeiBlockCopyDataPerAccess_K
,
index_t
OutThreadCopyDataPerAccess_N
>
struct
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
struct
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
...
@@ -108,8 +109,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -108,8 +109,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDataPer
Read
_N
,
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDataPer
Access
_N
,
WeiBlockCopyDataPer
Read
_K
,
WeiBlockCopyDataPer
Access
_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
...
@@ -130,24 +131,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -130,24 +131,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// blockwise copy
// blockwise copy
// input: format is [C, Hi, Wi, N]
// input: format is [C, Hi, Wi, N]
const
auto
blockwise_in_copy
=
auto
blockwise_in_copy
=
Blockwise4dTensorCopy3
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
Float
,
decltype
(
in_c_h_w_n_global_desc
),
decltype
(
in_c_h_w_n_global_desc
),
decltype
(
in_c_h_w_n_block_desc
),
decltype
(
in_c_h_w_n_block_desc
),
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()),
decltype
(
in_c_h_w_n_block_desc
.
GetLengths
()),
InBlockCopySubLengths_CHWN
,
InBlockCopyClusterLengths_CHWN
,
InBlockCopyClusterLengths_CHWN
,
InBlockCopyDataPerRead_N
>
{};
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
3
,
InBlockCopyDataPerAccess_N
,
InBlockCopyDataPerAccess_N
>
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
});
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock, X * KPerBlock]
// format is [CPerBlock, X * KPerBlock]
const
auto
blockwise_wei_copy
=
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
Float
,
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
WeiBlockCopySubLengths_CK
,
WeiBlockCopyDataPerRead_K
>
({
0
,
0
},
{
0
,
0
});
WeiBlockCopyClusterLengths_CK
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
1
,
WeiBlockCopyDataPerAccess_K
,
WeiBlockCopyDataPerAccess_K
>
({
0
,
0
},
{
0
,
0
});
// a series of blockwise batched GEMM
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// C_matrix += transpose(A_matrix) * B_matrix
...
@@ -233,18 +248,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -233,18 +248,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_global_block_offset
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_double
);
p_in_block_double
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_double
);
p_wei_block_double
);
}
}
// LDS double buffer: main body
// LDS double buffer: main body
...
@@ -266,9 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -266,9 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
Float
*
p_wei_block_next
=
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_in_register_buffer
[
blockwise_in_copy
.
GetRegisterBufferSize
()];
Float
Float
p_wei_register_buffer
[
blockwise_wei_copy
.
GetRegisterBufferSize
()];
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
p_in_global_block_offset
+=
p_in_global_block_offset
+=
CPerBlock
*
in_c_h_w_n_global_desc
.
GetStride
(
I0
);
CPerBlock
*
in_c_h_w_n_global_desc
.
GetStride
(
I0
);
...
@@ -278,25 +292,25 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -278,25 +292,25 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_global_block_offset
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
blockwise_batch_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_next
);
p_wei_block_next
);
}
}
}
}
// LDS double buffer: tail
// LDS double buffer: tail
{
{
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
// even iteration
// even iteration
p_in_global_block_offset
+=
CPerBlock
*
in_c_h_w_n_global_desc
.
GetStride
(
I0
);
p_in_global_block_offset
+=
CPerBlock
*
in_c_h_w_n_global_desc
.
GetStride
(
I0
);
...
@@ -305,19 +319,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -305,19 +319,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_global_block_offset
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_batch_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_buffer
,
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
...
@@ -330,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -330,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
}
}
}
}
// output: register to global mem
,
// output: register to global mem
const
auto
c_thread_mtx_begin
=
const
auto
c_thread_mtx_begin
=
blockwise_batch_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
blockwise_batch_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
...
@@ -381,17 +395,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -381,17 +395,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
}
}
#endif
#endif
threadwise_tensor_slice_copy
(
out_10d_thread_desc
,
Float
*
p_out_thread_on_global
=
p_out_global
+
p_out_thread
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
out_10d_global_desc
,
k_block_data_begin
+
k_thread_data_begin
,
p_out_global
+
ho_block_data_begin
+
ho_thread_data_begin
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
wo_block_data_begin
+
wo_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
);
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
ThreadwiseGenericTensorSliceCopy_v2r1
<
decltype
(
out_10d_thread_desc
),
n_block_data_begin
+
n_thread_data_begin
),
decltype
(
out_10d_global_desc
),
out_10d_thread_desc
.
GetLengths
(),
decltype
(
out_10d_thread_desc
.
GetLengths
()),
Number
<
OutThreadCopyDataPerWrite_N
>
{});
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
9
,
9
,
OutThreadCopyDataPerAccess_N
,
OutThreadCopyDataPerAccess_N
>
(
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
())
.
Run
(
p_out_thread
,
p_out_thread_on_global
);
}).
Else
([
&
](
auto
fwd
)
{
}).
Else
([
&
](
auto
fwd
)
{
static_assert
(
fwd
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
static_assert
(
fwd
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
GemmNPerThreadSubC
%
NPerThread
==
0
,
GemmNPerThreadSubC
%
NPerThread
==
0
,
...
@@ -432,17 +453,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
...
@@ -432,17 +453,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
}
}
#endif
#endif
threadwise_tensor_slice_copy
(
out_10d_thread_desc
,
Float
*
p_out_thread_on_global
=
p_out_global
+
p_out_thread
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
out_10d_global_desc
,
k_block_data_begin
+
k_thread_data_begin
,
p_out_global
+
ho_block_data_begin
+
ho_thread_data_begin
,
out_k_h_w_n_global_desc
.
GetOffsetFromMultiIndex
(
wo_block_data_begin
+
wo_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
);
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
ThreadwiseGenericTensorSliceCopy_v2r1
<
decltype
(
out_10d_thread_desc
),
n_block_data_begin
+
n_thread_data_begin
),
decltype
(
out_10d_global_desc
),
out_10d_thread_desc
.
GetLengths
(),
decltype
(
out_10d_thread_desc
.
GetLengths
()),
Number
<
OutThreadCopyDataPerWrite_N
>
{});
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
10
,
1
>::
type
,
9
,
9
,
OutThreadCopyDataPerAccess_N
,
OutThreadCopyDataPerAccess_N
>
(
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
())
.
Run
(
p_out_thread
,
p_out_thread_on_global
);
});
});
}
}
};
};
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -254,19 +254,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -254,19 +254,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
Float
p_in_register_clipboard
[
blockwise_in_copy_reorder
Float
p_in_register_buffer
[
blockwise_in_copy_reorder
.
GetRegisterBufferSize
()];
.
GetRegisterClipboardSize
()];
Float
p_wei_register_buffer
[
blockwise_wei_copy
.
GetRegisterBufferSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy_reorder
.
RunLoadRegisterBuffer
(
p_in_global_block_offset
,
blockwise_in_copy_reorder
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_buffer
);
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterBuffer
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_buffer
);
p_wei_register_clipboard
);
blockwise_in_copy_reorder
.
RunStoreRegisterBuffer
(
p_in_register_buffer
,
blockwise_in_copy_reorder
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
);
p_in_block_double
);
blockwise_wei_copy
.
RunStoreRegisterBuffer
(
p_wei_register_buffer
,
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_double
);
p_wei_block_double
);
}
}
// LDS double buffer: main body
// LDS double buffer: main body
...
@@ -288,10 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -288,10 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
Float
*
p_wei_block_next
=
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_register_clipboard
[
blockwise_in_copy_reorder
.
GetRegisterClipboardSize
()];
Float
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
p_in_register_buffer
[
blockwise_in_copy_reorder
.
GetRegisterBufferSize
()];
Float
p_wei_register_buffer
[
blockwise_wei_copy
.
GetRegisterBufferSize
()];
p_in_global_block_offset
+=
p_in_global_block_offset
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
...
@@ -301,27 +299,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -301,27 +299,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy_reorder
.
RunLoadRegister
Clipboard
(
p_in_global_block_offset
,
blockwise_in_copy_reorder
.
RunLoadRegister
Buffer
(
p_in_global_block_offset
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
run_blockwise_batch_gemm
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy_reorder
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_next
);
p_wei_block_next
);
}
}
}
}
// LDS double buffer: tail
// LDS double buffer: tail
{
{
Float
p_in_register_clipboard
[
blockwise_in_copy_reorder
Float
p_in_register_buffer
[
blockwise_in_copy_reorder
.
GetRegisterBufferSize
()];
.
GetRegisterClipboardSize
()];
Float
p_wei_register_buffer
[
blockwise_wei_copy
.
GetRegisterBufferSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
// even iteration
// even iteration
p_in_global_block_offset
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
p_in_global_block_offset
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
...
@@ -330,19 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -330,19 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy_reorder
.
RunLoadRegister
Clipboard
(
p_in_global_block_offset
,
blockwise_in_copy_reorder
.
RunLoadRegister
Buffer
(
p_in_global_block_offset
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
run_blockwise_batch_gemm
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder
.
RunStoreRegister
Clipboard
(
blockwise_in_copy_reorder
.
RunStoreRegister
Buffer
(
p_in_register_
clipboard
,
p_in_block_double
+
in_block_space
);
p_in_register_
buffer
,
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_buffer
,
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
View file @
8bdaba51
...
@@ -214,16 +214,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
...
@@ -214,16 +214,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
__syncthreads
())
__syncthreads
())
{
{
// load data
// load data
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegisterBuffer
(
p_in_global_block_offset
,
p_in_register_buffer
);
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterBuffer
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_buffer
);
p_wei_register_clipboard
);
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
p_in_block
);
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
p_wei_block
);
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block
);
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -209,17 +209,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -209,17 +209,15 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
// preload data into LDS
// preload data into LDS
{
{
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegisterBuffer
(
p_in_global_block_offset
,
p_in_register_buffer
);
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterBuffer
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_buffer
);
p_wei_register_clipboard
);
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
);
blockwise_in_copy
.
RunStoreRegisterBuffer
(
p_in_register_buffer
,
p_in_block_double
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
blockwise_wei_copy
.
RunStoreRegisterBuffer
(
p_wei_register_buffer
,
p_wei_block_double
);
p_wei_block_double
);
}
}
// register
// register
...
@@ -247,18 +245,18 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -247,18 +245,18 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
// load next data
// load next data
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
p_in_global_block_offset
+=
CPerBlock
*
in_cb_global_desc
.
GetStride
(
I0
);
p_in_global_block_offset
+=
CPerBlock
*
in_cb_global_desc
.
GetStride
(
I0
);
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
);
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
);
__syncthreads
();
__syncthreads
();
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_global_block_offset
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// compute on current data
// compute on current data
// a series of GEMM
// a series of GEMM
...
@@ -280,10 +278,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -280,10 +278,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
}
}
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterBuffer
(
p_in_register_buffer
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterBuffer
(
p_wei_register_buffer
,
p_wei_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
}
}
}
}
...
@@ -295,14 +291,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -295,14 +291,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
__syncthreads
();
__syncthreads
();
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegisterBuffer
(
p_in_global_block_offset
,
p_in_register_buffer
);
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_global_block_offset
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
{
...
@@ -322,10 +317,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -322,10 +317,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
}
}
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd
// odd
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -267,9 +267,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -267,9 +267,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
Float
*
p_wei_block_next
=
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_in_register_buffer
[
blockwise_in_copy
.
GetRegisterBufferSize
()];
Float
Float
p_wei_register_buffer
[
blockwise_wei_copy
.
GetRegisterBufferSize
()];
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
p_in_block_on_global
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
p_in_block_on_global
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
p_wei_block_on_global
+=
CPerBlock
*
wei_c_y_x_k_global_desc
.
GetStride
(
I0
);
p_wei_block_on_global
+=
CPerBlock
*
wei_c_y_x_k_global_desc
.
GetStride
(
I0
);
...
@@ -277,26 +276,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -277,26 +276,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_block_on_global
,
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_block_on_global
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_block_on_global
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_next
);
p_wei_block_next
);
}
}
}
}
// LDS double buffer: tail
// LDS double buffer: tail
{
{
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
// even iteration
// even iteration
p_in_block_on_global
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
p_in_block_on_global
+=
CPerBlock
*
in_n_c_h_w_global_desc
.
GetStride
(
I1
);
...
@@ -305,19 +304,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -305,19 +304,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_block_on_global
,
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_block_on_global
,
p_in_register_
clipboard
);
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_block_on_global
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_buffer
,
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -176,22 +176,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -176,22 +176,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyDstDataPerWrite_N2>(
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
#else
#else
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v2
<
auto
blockwise_in_copy
=
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
decltype
(
in_e_n1_b_n2_global_merged_desc
),
decltype
(
in_e_n1_b_n2_global_merged_desc
),
decltype
(
in_e_n1_b_n2_block_desc
),
decltype
(
in_e_n1_b_n2_block_desc
),
MergedTensorCoordinate
<
decltype
(
in_e_n1_b_n2_global_merged_desc
)
>
,
decltype
(
in_e_n1_b_n2_block_desc
.
GetLengths
()),
NormalTensorCoordinate
<
decltype
(
in_e_n1_b_n2_block_desc
)
>
,
InBlockCopySubLengths_E_N1_B_N2
,
decltype
(
in_e_n1_b_n2_block_desc
.
GetLengths
()),
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopySubLengths_E_N1_B_N2
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopySrcAccessOrder
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopyDstAccessOrder
,
InBlockCopySrcAccessOrder
,
2
,
InBlockCopyDstAccessOrder
,
3
,
2
,
InBlockCopySrcDataPerRead_B
,
3
,
InBlockCopyDstDataPerWrite_N2
>
(
InBlockCopySrcDataPerRead_B
,
{
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
InBlockCopyDstDataPerWrite_N2
>
({
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
#endif
#endif
// weight tensor
// weight tensor
...
@@ -225,22 +224,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -225,22 +224,21 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
{0, k_block_data_on_global}, {0, 0});
#else
#else
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v2
<
auto
blockwise_wei_copy
=
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
),
NormalTensorCoordinate
<
decltype
(
wei_e_k_global_desc
)
>
,
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
NormalTensorCoordinate
<
decltype
(
wei_e_k_block_desc
)
>
,
WeiBlockCopySubLengths_E_K
,
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcAccessOrder
,
0
,
WeiBlockCopyDstAccessOrder
,
1
,
0
,
WeiBlockCopySrcDataPerRead_E
,
1
,
WeiBlockCopyDstDataPerWrite_K
>
(
WeiBlockCopySrcDataPerRead_E
,
{
0
,
k_block_data_on_global
},
{
0
,
0
});
WeiBlockCopyDstDataPerWrite_K
>
({
0
,
k_block_data_on_global
},
{
0
,
0
});
#endif
#endif
// GEMM definition
// GEMM definition
...
@@ -448,8 +446,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -448,8 +446,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
ThreadwiseGenericTensorSliceCopy_v2r1
<
ThreadwiseGenericTensorSliceCopy_v2r1
<
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
),
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
),
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
),
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
),
NormalTensorCoordinate
<
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
)
>
,
MergedTensorCoordinate
<
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
)
>
,
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
()),
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
()),
arithmetic_sequence_gen
<
0
,
8
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
8
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
8
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
8
,
1
>::
type
,
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -313,8 +313,8 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -313,8 +313,8 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
Float
*
p_wei_block_next
=
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
p_wei_block_on_global
+=
EPerBlock
*
wei_e_k_global_desc
.
GetStride
(
I0
);
p_wei_block_on_global
+=
EPerBlock
*
wei_e_k_global_desc
.
GetStride
(
I0
);
...
@@ -322,25 +322,23 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -322,25 +322,23 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_global
,
p_in_register_
clipboard
);
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_global
,
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_block_on_global
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterBuffer
(
p_in_register_buffer
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterBuffer
(
p_wei_register_buffer
,
p_wei_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
}
}
}
}
// LDS double buffer: tail
// LDS double buffer: tail
{
{
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
// even iteration
// even iteration
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
...
@@ -349,18 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -349,18 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global
,
p_in_register_clipboard
);
blockwise_in_copy
.
RunLoadRegisterBuffer
(
p_in_global
,
p_in_register_buffer
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegisterBuffer
(
p_wei_block_on_global
,
p_wei_register_buffer
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -319,8 +319,8 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -319,8 +319,8 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
Float
*
p_wei_block_next
=
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
p_wei_block_on_global
+=
EPerBlock
*
wei_e_k_global_desc
.
GetStride
(
I0
);
p_wei_block_on_global
+=
EPerBlock
*
wei_e_k_global_desc
.
GetStride
(
I0
);
...
@@ -328,9 +328,9 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -328,9 +328,9 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegister
Clipboard
(
p_in_global
,
p_in_register_
clipboard
);
blockwise_in_copy
.
RunLoadRegister
Buffer
(
p_in_global
,
p_in_register_
buffer
);
blockwise_wei_copy
.
RunLoadRegister
Clipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegister
Buffer
(
p_wei_block_on_global
,
p_wei_register_
clipboard
);
p_wei_register_
buffer
);
#if 0
#if 0
if(get_block_1d_id() == 0)
if(get_block_1d_id() == 0)
...
@@ -338,10 +338,10 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -338,10 +338,10 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
printf("tid (%d %d), %f %f %f %f\n",
printf("tid (%d %d), %f %f %f %f\n",
get_block_1d_id(),
get_block_1d_id(),
get_thread_local_1d_id(),
get_thread_local_1d_id(),
p_wei_register_
clipboard
[0],
p_wei_register_
buffer
[0],
p_wei_register_
clipboard
[1],
p_wei_register_
buffer
[1],
p_wei_register_
clipboard
[2],
p_wei_register_
buffer
[2],
p_wei_register_
clipboard
[3]);
p_wei_register_
buffer
[3]);
}
}
#endif
#endif
...
@@ -349,17 +349,15 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -349,17 +349,15 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterBuffer
(
p_in_register_buffer
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterBuffer
(
p_wei_register_buffer
,
p_wei_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
}
}
}
}
// LDS double buffer: tail
// LDS double buffer: tail
{
{
Float
p_in_register_
clipboard
[
blockwise_in_copy
.
GetRegister
Clipboard
Size
()];
Float
p_in_register_
buffer
[
blockwise_in_copy
.
GetRegister
Buffer
Size
()];
Float
p_wei_register_
clipboard
[
blockwise_wei_copy
.
GetRegister
Clipboard
Size
()];
Float
p_wei_register_
buffer
[
blockwise_wei_copy
.
GetRegister
Buffer
Size
()];
// even iteration
// even iteration
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
...
@@ -368,18 +366,17 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -368,18 +366,17 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global
,
p_in_register_clipboard
);
blockwise_in_copy
.
RunLoadRegisterBuffer
(
p_in_global
,
p_in_register_buffer
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegisterBuffer
(
p_wei_block_on_global
,
p_wei_register_buffer
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegister
Clipboard
(
p_in_register_
clipboard
,
blockwise_in_copy
.
RunStoreRegister
Buffer
(
p_in_register_
buffer
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegister
Clipboard
(
p_wei_register_
clipboard
,
blockwise_wei_copy
.
RunStoreRegister
Buffer
(
p_wei_register_
buffer
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
8bdaba51
...
@@ -134,8 +134,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -134,8 +134,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
decltype
(
in_e_b_global_desc
),
decltype
(
in_e_b_global_desc
),
decltype
(
in_e_b_block_desc
),
decltype
(
in_e_b_block_desc
),
MergedTensorCoordinate
<
decltype
(
in_e_b_global_desc
)
>
,
NormalTensorCoordinate
<
decltype
(
in_e_b_block_desc
)
>
,
decltype
(
in_e_b_block_desc
.
GetLengths
()),
decltype
(
in_e_b_block_desc
.
GetLengths
()),
InBlockCopySubLengths_E_B
,
InBlockCopySubLengths_E_B
,
InBlockCopyClusterLengths_E_B
,
InBlockCopyClusterLengths_E_B
,
...
@@ -162,22 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -162,22 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// operator for blockwise copy of weight into LDS
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in
// this copy operator already have blockwise offset built-in
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v2
<
auto
blockwise_wei_copy
=
BlockSize
,
BlockwiseGenericTensorSliceCopy_v2
<
BlockSize
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
),
NormalTensorCoordinate
<
decltype
(
wei_e_k_global_desc
)
>
,
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
NormalTensorCoordinate
<
decltype
(
wei_e_k_block_desc
)
>
,
WeiBlockCopySubLengths_E_K
,
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcAccessOrder
,
0
,
WeiBlockCopyDstAccessOrder
,
1
,
0
,
WeiBlockCopySrcDataPerRead_E
,
1
,
WeiBlockCopyDstDataPerWrite_K
>
(
WeiBlockCopySrcDataPerRead_E
,
{
0
,
k_block_data_on_global
},
{
0
,
0
});
WeiBlockCopyDstDataPerWrite_K
>
({
0
,
k_block_data_on_global
},
{
0
,
0
});
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
@@ -365,21 +362,20 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -365,21 +362,20 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
using
OutThreadCopySliceLengths
=
using
OutThreadCopySliceLengths
=
Sequence
<
GemmMRepeat
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
>
;
Sequence
<
GemmMRepeat
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
>
;
auto
threadwise_out_copy
=
ThreadwiseGenericTensorSliceCopy_v2r1
<
auto
threadwise_out_copy
=
decltype
(
out_k0_k1_b_thread_desc
),
ThreadwiseGenericTensorSliceCopy_v2r1
<
decltype
(
out_k0_k1_b_thread_desc
),
decltype
(
out_k0_k1_b_global_desc
),
decltype
(
out_k0_k1_b_global_desc
),
NormalTensorCoordinate
<
decltype
(
out_k0_k1_b_thread_desc
)
>
,
OutThreadCopySliceLengths
,
MergedTensorCoordinate
<
decltype
(
out_k0_k1_b_global_desc
)
>
,
arithmetic_sequence_gen
<
0
,
3
,
1
>::
type
,
OutThreadCopySliceLengths
,
arithmetic_sequence_gen
<
0
,
3
,
1
>::
type
,
arithmetic_sequence_gen
<
0
,
3
,
1
>::
type
,
2
,
arithmetic_sequence_gen
<
0
,
3
,
1
>::
type
,
2
,
2
,
OutThreadCopyDataPerAccess_B
,
2
,
OutThreadCopyDataPerAccess_B
>
(
OutThreadCopyDataPerAccess_B
,
{
0
,
0
,
0
},
OutThreadCopyDataPerAccess_B
>
({
0
,
0
,
0
},
{
k_thread_data_on_global
/
K1
,
{
k_thread_data_on_global
/
K1
,
k_thread_data_on_global
%
K1
,
k_thread_data_on_global
%
K1
,
b_thread_data_on_global
});
b_thread_data_on_global
});
for
(
index_t
nrepeat
=
0
;
nrepeat
<
GemmNRepeat
;
++
nrepeat
)
for
(
index_t
nrepeat
=
0
;
nrepeat
<
GemmNRepeat
;
++
nrepeat
)
{
{
...
...
composable_kernel/include/tensor_description/tensor_coordinate.hpp
View file @
8bdaba51
...
@@ -295,5 +295,27 @@ struct MergedTensorCoordinate
...
@@ -295,5 +295,27 @@ struct MergedTensorCoordinate
index_t
mOffset
;
index_t
mOffset
;
};
};
template
<
class
TensorDesc
>
struct
TensorCoordinate
{
private:
template
<
class
...
Ts
>
__host__
__device__
static
constexpr
auto
MakeDummyTensorCoordinate
(
ConstantTensorDescriptor
<
Ts
...
>
)
{
return
NormalTensorCoordinate
<
ConstantTensorDescriptor
<
Ts
...
>>
();
}
template
<
class
...
Ts
>
__host__
__device__
static
constexpr
auto
MakeDummyTensorCoordinate
(
ConstantMergedTensorDescriptor
<
Ts
...
>
)
{
return
MergedTensorCoordinate
<
ConstantMergedTensorDescriptor
<
Ts
...
>>
();
}
public:
using
type
=
decltype
(
MakeDummyTensorCoordinate
(
TensorDesc
{}));
};
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/tensor_operation/blockwise_2d_tensor_op.hpp
View file @
8bdaba51
...
@@ -563,7 +563,7 @@ struct Blockwise2dTensorCopy3
...
@@ -563,7 +563,7 @@ struct Blockwise2dTensorCopy3
}
}
}
}
__device__
constexpr
index_t
GetRegister
Clipboard
Size
()
const
__device__
constexpr
index_t
GetRegister
Buffer
Size
()
const
{
{
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float!
\n
"
);
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float!
\n
"
);
...
@@ -579,8 +579,8 @@ struct Blockwise2dTensorCopy3
...
@@ -579,8 +579,8 @@ struct Blockwise2dTensorCopy3
return
DataPerRead
*
(
L0
+
thread_per_d0
-
1
)
/
thread_per_d0
;
return
DataPerRead
*
(
L0
+
thread_per_d0
-
1
)
/
thread_per_d0
;
}
}
__device__
void
RunLoadRegister
Clipboard
(
const
Float
*
__restrict__
p_src
,
__device__
void
RunLoadRegister
Buffer
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
Float
*
__restrict__
p_clipboard
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -630,8 +630,8 @@ struct Blockwise2dTensorCopy3
...
@@ -630,8 +630,8 @@ struct Blockwise2dTensorCopy3
}
}
}
}
__device__
void
RunStoreRegister
Clipboard
(
const
Float
*
__restrict__
p_clipboard
,
__device__
void
RunStoreRegister
Buffer
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
Float
*
__restrict__
p_dst
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -681,8 +681,8 @@ struct Blockwise2dTensorCopy3
...
@@ -681,8 +681,8 @@ struct Blockwise2dTensorCopy3
}
}
#if CK_USE_AMD_INLINE_ASM
#if CK_USE_AMD_INLINE_ASM
__device__
void
RunLoadRegister
Clipboard
_asm
(
const
Float
*
__restrict__
p_src
,
__device__
void
RunLoadRegister
Buffer
_asm
(
const
Float
*
__restrict__
p_src
,
Float
*
p_clipboard
)
const
Float
*
p_clipboard
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -741,8 +741,8 @@ struct Blockwise2dTensorCopy3
...
@@ -741,8 +741,8 @@ struct Blockwise2dTensorCopy3
}
}
}
}
__device__
void
RunStoreRegister
Clipboard
_asm
(
const
Float
*
__restrict__
p_clipboard
,
__device__
void
RunStoreRegister
Buffer
_asm
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
Float
*
__restrict__
p_dst
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp
View file @
8bdaba51
...
@@ -237,7 +237,7 @@ struct Blockwise3dTensorCopy3
...
@@ -237,7 +237,7 @@ struct Blockwise3dTensorCopy3
}
}
}
}
__device__
static
constexpr
index_t
GetRegister
Clipboard
Size
()
__device__
static
constexpr
index_t
GetRegister
Buffer
Size
()
{
{
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float!
\n
"
);
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float!
\n
"
);
...
@@ -260,8 +260,8 @@ struct Blockwise3dTensorCopy3
...
@@ -260,8 +260,8 @@ struct Blockwise3dTensorCopy3
return
DataPerRead
*
nloop_d0
*
nloop_d1
*
nloop_d2
;
return
DataPerRead
*
nloop_d0
*
nloop_d1
*
nloop_d2
;
}
}
__device__
void
RunLoadRegister
Clipboard
(
const
Float
*
__restrict__
p_src
,
__device__
void
RunLoadRegister
Buffer
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
Float
*
__restrict__
p_clipboard
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -316,8 +316,8 @@ struct Blockwise3dTensorCopy3
...
@@ -316,8 +316,8 @@ struct Blockwise3dTensorCopy3
}
}
}
}
__device__
void
RunStoreRegister
Clipboard
(
const
Float
*
__restrict__
p_clipboard
,
__device__
void
RunStoreRegister
Buffer
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
Float
*
__restrict__
p_dst
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp
View file @
8bdaba51
...
@@ -596,7 +596,7 @@ struct Blockwise4dTensorCopy3
...
@@ -596,7 +596,7 @@ struct Blockwise4dTensorCopy3
}
}
}
}
__device__
constexpr
index_t
GetRegister
Clipboard
Size
()
const
__device__
constexpr
index_t
GetRegister
Buffer
Size
()
const
{
{
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float!
\n
"
);
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float!
\n
"
);
...
@@ -623,8 +623,8 @@ struct Blockwise4dTensorCopy3
...
@@ -623,8 +623,8 @@ struct Blockwise4dTensorCopy3
return
DataPerRead
*
nloop_d0
*
nloop_d1
*
nloop_d2
*
nloop_d3
;
return
DataPerRead
*
nloop_d0
*
nloop_d1
*
nloop_d2
*
nloop_d3
;
}
}
__device__
void
RunLoadRegister
Clipboard
(
const
Float
*
__restrict__
p_src
,
__device__
void
RunLoadRegister
Buffer
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
Float
*
__restrict__
p_clipboard
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -690,8 +690,8 @@ struct Blockwise4dTensorCopy3
...
@@ -690,8 +690,8 @@ struct Blockwise4dTensorCopy3
}
}
}
}
__device__
void
RunStoreRegister
Clipboard
(
const
Float
*
__restrict__
p_clipboard
,
__device__
void
RunStoreRegister
Buffer
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
Float
*
__restrict__
p_dst
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @
8bdaba51
...
@@ -420,8 +420,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
...
@@ -420,8 +420,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
class
SrcDesc
,
class
SrcDesc
,
class
DstDesc
,
class
DstDesc
,
class
SrcCoordinate
,
class
DstCoordinate
,
class
SliceLengths
,
class
SliceLengths
,
class
SubLengths
,
class
SubLengths
,
class
ThreadClusterLengths
,
class
ThreadClusterLengths
,
...
@@ -436,6 +434,9 @@ struct BlockwiseGenericTensorSliceCopy_v2
...
@@ -436,6 +434,9 @@ struct BlockwiseGenericTensorSliceCopy_v2
{
{
static
constexpr
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
static
constexpr
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
SrcCoordinate
=
typename
TensorCoordinate
<
SrcDesc
>::
type
;
using
DstCoordinate
=
typename
TensorCoordinate
<
DstDesc
>::
type
;
__device__
constexpr
BlockwiseGenericTensorSliceCopy_v2
(
SrcCoordinate
src_block_slice_origin
,
__device__
constexpr
BlockwiseGenericTensorSliceCopy_v2
(
SrcCoordinate
src_block_slice_origin
,
DstCoordinate
dst_block_slice_origin
)
DstCoordinate
dst_block_slice_origin
)
{
{
...
@@ -515,31 +516,25 @@ struct BlockwiseGenericTensorSliceCopy_v2
...
@@ -515,31 +516,25 @@ struct BlockwiseGenericTensorSliceCopy_v2
private:
private:
using
RegisterBufferDesc
=
decltype
(
make_ConstantTensorDescriptor_packed
(
SubLengths
{}));
using
RegisterBufferDesc
=
decltype
(
make_ConstantTensorDescriptor_packed
(
SubLengths
{}));
using
ThreadwiseLoad
=
using
ThreadwiseLoad
=
ThreadwiseGenericTensorSliceCopy_v2r1
<
SrcDesc
,
ThreadwiseGenericTensorSliceCopy_v2r1
<
SrcDesc
,
RegisterBufferDesc
,
RegisterBufferDesc
,
SubLengths
,
SrcCoordinate
,
SrcDimAccessOrder
,
NormalTensorCoordinate
<
RegisterBufferDesc
>
,
SrcDimAccessOrder
,
SubLengths
,
SrcVectorAccessDim
,
SrcDimAccessOrder
,
SrcVectorAccessDim
,
SrcDimAccessOrder
,
SrcDataPerAccess
,
SrcVectorAccessDim
,
1
>
;
SrcVectorAccessDim
,
SrcDataPerAccess
,
using
ThreadwiseStore
=
ThreadwiseGenericTensorSliceCopy_v2r1
<
RegisterBufferDesc
,
1
>
;
DstDesc
,
SubLengths
,
using
ThreadwiseStore
=
DstDimAccessOrder
,
ThreadwiseGenericTensorSliceCopy_v2r1
<
RegisterBufferDesc
,
DstDimAccessOrder
,
DstDesc
,
DstVectorAccessDim
,
NormalTensorCoordinate
<
RegisterBufferDesc
>
,
DstVectorAccessDim
,
DstCoordinate
,
1
,
SubLengths
,
DstDataPerAccess
>
;
DstDimAccessOrder
,
DstDimAccessOrder
,
DstVectorAccessDim
,
DstVectorAccessDim
,
1
,
DstDataPerAccess
>
;
ThreadwiseLoad
mThreadwiseLoad
;
ThreadwiseLoad
mThreadwiseLoad
;
ThreadwiseStore
mThreadwiseStore
;
ThreadwiseStore
mThreadwiseStore
;
...
...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_copy.hpp
View file @
8bdaba51
...
@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
#endif
#endif
}
}
__device__
static
constexpr
index_t
GetRegister
Clipboard
Size
()
__device__
static
constexpr
index_t
GetRegister
Buffer
Size
()
{
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
...
@@ -183,8 +183,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -183,8 +183,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
return
thread_tensor_desc
.
GetElementSpace
();
return
thread_tensor_desc
.
GetElementSpace
();
}
}
__device__
void
RunLoadRegister
Clipboard
(
const
Float
*
__restrict__
p_src
,
__device__
void
RunLoadRegister
Buffer
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
Float
*
__restrict__
p_clipboard
)
const
{
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
...
@@ -219,8 +219,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -219,8 +219,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
});
});
}
}
__device__
void
RunStoreRegister
Clipboard
(
const
Float
*
__restrict__
p_clipboard
,
__device__
void
RunStoreRegister
Buffer
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
Float
*
__restrict__
p_dst
)
const
{
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
...
@@ -274,10 +274,10 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -274,10 +274,10 @@ struct BlockwiseTensorSliceReorderCopy_v3
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
{
{
Float
p_clipboard
[
GetRegister
Clipboard
Size
()];
Float
p_clipboard
[
GetRegister
Buffer
Size
()];
RunLoadRegister
Clipboard
(
p_src
,
p_clipboard
);
RunLoadRegister
Buffer
(
p_src
,
p_clipboard
);
RunStoreRegister
Clipboard
(
p_clipboard
,
p_dst
);
RunStoreRegister
Buffer
(
p_clipboard
,
p_dst
);
}
}
// this function doesn't do santiy check on whether the slicing window is out of the boundary
// this function doesn't do santiy check on whether the slicing window is out of the boundary
...
...
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
View file @
8bdaba51
...
@@ -14,10 +14,6 @@
...
@@ -14,10 +14,6 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#endif
#endif
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#endif
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#endif
#endif
...
@@ -430,170 +426,6 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
...
@@ -430,170 +426,6 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
Array
<
index_t
,
nDim
>
mDstSliceOrigin
;
Array
<
index_t
,
nDim
>
mDstSliceOrigin
;
};
};
template
<
class
SrcDesc
,
class
DstDesc
,
class
SrcCoordinate
,
class
DstCoordinate
,
class
SliceLengths
>
struct
ThreadwiseGenericTensorSliceCopy_v2
{
static
constexpr
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
__device__
constexpr
ThreadwiseGenericTensorSliceCopy_v2
(
SrcCoordinate
src_slice_origin
,
DstCoordinate
dst_slice_origin
)
:
mSrcSliceOrigin
(
src_slice_origin
),
mDstSliceOrigin
(
dst_slice_origin
)
{
}
__device__
constexpr
ThreadwiseGenericTensorSliceCopy_v2
()
:
ThreadwiseGenericTensorSliceCopy_v2
(
make_zero_array
<
index_t
,
nDim
>
(),
make_zero_array
<
index_t
,
nDim
>
())
{
}
__device__
void
SetSrcSliceOrigin
(
SrcCoordinate
src_slice_origin
)
{
mSrcSliceOrigin
=
src_slice_origin
;
}
__device__
void
SetDstSliceOrigin
(
DstCoordinate
dst_slice_origin
)
{
mDstSliceOrigin
=
dst_slice_origin
;
}
template
<
class
TDesc
,
class
Seq
>
struct
IsolateMergedDimSliceLengthsHack
{
template
<
class
IDim
>
__device__
constexpr
index_t
operator
()(
IDim
idim
)
const
{
return
TDesc
::
ContainMultipleOriginalDimensions
(
idim
)
?
Seq
{}[
idim
]
:
1
;
}
};
template
<
class
TData
>
__device__
void
Run
(
const
TData
*
p_src
,
TData
*
p_dst
)
const
{
constexpr
auto
buffer_desc
=
make_ConstantTensorDescriptor_packed
(
SliceLengths
{});
TData
p_buffer_
[
buffer_desc
.
GetElementSpace
()];
TData
*
p_buffer
=
p_buffer_
;
// hacks to isolate merged dimension from normal dimensions, and calculate their offset
// seperately
// SrcMergedDimSliceLengthsHack has entry same as SliceLengths on src merged dimensions,
// but 1 on normal dimensions;
// SrcNormalDimSliceLengthsHack has entry same as SliceLengths on src normal dimensions,
// but 1 on merged dimensions;
using
SrcMergedDimSliceLengthsHack
=
typename
sequence_gen
<
SliceLengths
::
GetSize
(),
IsolateMergedDimSliceLengthsHack
<
SrcDesc
,
SliceLengths
>>::
type
;
using
SrcNormalDimSliceLengthsHack
=
decltype
((
SliceLengths
{}
+
Number
<
1
>
{})
-
SrcMergedDimSliceLengthsHack
{});
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2
static_ford
<
SrcMergedDimSliceLengthsHack
>
{}([
&
](
auto
merged_dim_data_id_
)
{
constexpr
auto
merged_dim_data_id
=
decltype
(
merged_dim_data_id_
){};
const
TData
*
p_src_tmp
=
p_src
+
(
mSrcSliceOrigin
+
merged_dim_data_id
).
GetOffset
();
static_ford
<
SrcNormalDimSliceLengthsHack
>
{}([
&
](
auto
normal_dim_data_id_
)
{
constexpr
auto
normal_dim_data_id
=
decltype
(
normal_dim_data_id_
){};
constexpr
index_t
buffer_offset
=
buffer_desc
.
GetOffsetFromMultiIndex
(
merged_dim_data_id
+
normal_dim_data_id
);
constexpr
index_t
src_normal_offset
=
SrcDesc
::
GetOffsetFromMultiIndex
(
normal_dim_data_id
);
p_buffer
[
buffer_offset
]
=
p_src_tmp
[
src_normal_offset
];
});
});
#else
ford
<
SrcMergedDimSliceLengthsHack
>
{}([
&
](
auto
merged_dim_data_id
)
{
const
TData
*
p_src_tmp
=
p_src
+
(
mSrcSliceOrigin
+
merged_dim_data_id
).
GetOffset
();
ford
<
SrcNormalDimSliceLengthsHack
>
{}([
&
](
auto
normal_dim_data_id
)
{
const
index_t
buffer_offset
=
buffer_desc
.
GetOffsetFromMultiIndex
(
merged_dim_data_id
+
normal_dim_data_id
);
const
index_t
src_normal_offset
=
SrcDesc
::
GetOffsetFromMultiIndex
(
normal_dim_data_id
);
p_buffer
[
buffer_offset
]
=
p_src_tmp
[
src_normal_offset
];
});
});
#endif
// DstMergedDimSliceLengthsHack has entry same as SliceLengths on dst merged dimensions,
// but 1 on normal dimensions;
// DstNormalDimSliceLengthsHack has entry same as SliceLengths on dst normal dimensions,
// but 1 on merged dimensions;
using
DstMergedDimSliceLengthsHack
=
typename
sequence_gen
<
SliceLengths
::
GetSize
(),
IsolateMergedDimSliceLengthsHack
<
DstDesc
,
SliceLengths
>>::
type
;
using
DstNormalDimSliceLengthsHack
=
decltype
((
SliceLengths
{}
+
Number
<
1
>
{})
-
DstMergedDimSliceLengthsHack
{});
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2
static_ford
<
DstMergedDimSliceLengthsHack
>
{}([
&
](
auto
merged_dim_data_id_
)
{
constexpr
auto
merged_dim_data_id
=
decltype
(
merged_dim_data_id_
){};
TData
*
p_dst_tmp
=
p_dst
+
(
mDstSliceOrigin
+
merged_dim_data_id
).
GetOffset
();
static_ford
<
DstNormalDimSliceLengthsHack
>
{}([
&
](
auto
normal_dim_data_id_
)
{
constexpr
auto
normal_dim_data_id
=
decltype
(
normal_dim_data_id_
){};
constexpr
index_t
buffer_offset
=
buffer_desc
.
GetOffsetFromMultiIndex
(
merged_dim_data_id
+
normal_dim_data_id
);
constexpr
index_t
dst_normal_offset
=
DstDesc
::
GetOffsetFromMultiIndex
(
normal_dim_data_id
);
p_dst_tmp
[
dst_normal_offset
]
=
p_buffer
[
buffer_offset
];
});
});
#else
ford
<
DstMergedDimSliceLengthsHack
>
{}([
&
](
auto
merged_dim_data_id
)
{
TData
*
p_dst_tmp
=
p_dst
+
(
mDstSliceOrigin
+
merged_dim_data_id
).
GetOffset
();
ford
<
DstNormalDimSliceLengthsHack
>
{}([
&
](
auto
normal_dim_data_id
)
{
const
index_t
buffer_offset
=
buffer_desc
.
GetOffsetFromMultiIndex
(
merged_dim_data_id
+
normal_dim_data_id
);
const
index_t
dst_normal_offset
=
DstDesc
::
GetOffsetFromMultiIndex
(
normal_dim_data_id
);
p_dst_tmp
[
dst_normal_offset
]
=
p_buffer
[
buffer_offset
];
});
});
#endif
}
// T can be Sequence or Array
template
<
class
T
,
bool
PositiveDirection
>
__device__
void
MoveSrcSlicingWindow
(
T
step_sizes
,
integral_constant
<
bool
,
PositiveDirection
>
)
{
static_if
<
PositiveDirection
>
{}([
&
](
auto
)
{
mSrcSliceOrigin
+=
step_sizes
;
}).
Else
([
&
](
auto
)
{
mSrcSliceOrigin
-=
step_sizes
;
});
}
template
<
class
T
,
bool
PositiveDirection
>
__device__
void
MoveDstSlicingWindow
(
T
step_sizes
,
integral_constant
<
bool
,
PositiveDirection
>
)
{
static_if
<
PositiveDirection
>
{}([
&
](
auto
)
{
mDstSliceOrigin
+=
step_sizes
;
}).
Else
([
&
](
auto
)
{
mDstSliceOrigin
-=
step_sizes
;
});
}
private:
SrcCoordinate
mSrcSliceOrigin
;
DstCoordinate
mDstSliceOrigin
;
};
// This threadwise copy allow vector access of src and dst.
// This threadwise copy allow vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the vector size to be different on src and dst.
...
@@ -605,8 +437,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2
...
@@ -605,8 +437,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2
// used for the buffer.
// used for the buffer.
template
<
class
SrcDesc
,
template
<
class
SrcDesc
,
class
DstDesc
,
class
DstDesc
,
class
SrcCoordinate
,
class
DstCoordinate
,
class
SliceLengths
,
class
SliceLengths
,
class
SrcDimAccessOrder
,
class
SrcDimAccessOrder
,
class
DstDimAccessOrder
,
class
DstDimAccessOrder
,
...
@@ -618,6 +448,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
...
@@ -618,6 +448,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
{
{
static
constexpr
index_t
nDim
=
SliceLengths
::
GetSize
();
static
constexpr
index_t
nDim
=
SliceLengths
::
GetSize
();
using
SrcCoordinate
=
typename
TensorCoordinate
<
SrcDesc
>::
type
;
using
DstCoordinate
=
typename
TensorCoordinate
<
DstDesc
>::
type
;
__device__
constexpr
ThreadwiseGenericTensorSliceCopy_v2r1
(
SrcCoordinate
src_slice_origin
,
__device__
constexpr
ThreadwiseGenericTensorSliceCopy_v2r1
(
SrcCoordinate
src_slice_origin
,
DstCoordinate
dst_slice_origin
)
DstCoordinate
dst_slice_origin
)
:
mSrcSliceOrigin
(
src_slice_origin
),
mDstSliceOrigin
(
dst_slice_origin
)
:
mSrcSliceOrigin
(
src_slice_origin
),
mDstSliceOrigin
(
dst_slice_origin
)
...
...
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
View file @
8bdaba51
...
@@ -107,11 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -107,11 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPer
Read
_N = 4;
constexpr index_t InBlockCopyDataPer
Access
_N = 4;
constexpr index_t WeiBlockCopyDataPer
Read
_K = 4;
constexpr index_t WeiBlockCopyDataPer
Access
_K = 4;
constexpr index_t OutThreadCopyDataPer
Write
_N = 2;
constexpr index_t OutThreadCopyDataPer
Access
_N = 2;
#elif
0
#elif
0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
@@ -137,12 +137,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -137,12 +137,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
0
,
0
,
0
,
0
>
;
// not used
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
0
,
0
,
0
,
0
>
;
// not used
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
#elif 1
#elif 1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
...
@@ -170,12 +170,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -170,12 +170,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
2
,
4
>
;
using
InBlockCopySubLengths_CHWN
=
Sequence
<
1
,
1
,
1
,
4
>
;
constexpr
index_t
InBlockCopyDataPerRead_N
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
2
,
4
>
;
constexpr
index_t
InBlockCopyDataPerAccess_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
using
WeiBlockCopySubLengths_CK
=
Sequence
<
2
,
4
>
;
using
WeiBlockCopyClusterLengths_CK
=
Sequence
<
4
,
32
>
;
constexpr
index_t
WeiBlockCopyDataPerAccess_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
#elif 0
#elif 0
// for 3x3, 34x34, v1r3, Pascal, bad
// for 3x3, 34x34, v1r3, Pascal, bad
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
@@ -201,12 +204,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -201,12 +204,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
2
,
2
,
32
,
1
>
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
2
,
2
,
32
,
1
>
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
1
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
1
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
2
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
1
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
1
;
#elif 0
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
// for 3x3, 34x34, v1r1, Vega 20
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -232,12 +235,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -232,12 +235,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
4
,
2
,
8
>
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
4
,
2
,
8
>
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
2
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
2
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
2
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
4
;
#elif 1
#elif 1
// for 3x3, 34x34, v1r3, Vega 20
// for 3x3, 34x34, v1r3, Vega 20
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -263,12 +266,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -263,12 +266,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
4
,
4
>
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
4
,
4
>
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
4
;
#elif 0
#elif 0
// for 3x3, 56x56, v1r1, Pascal
// for 3x3, 56x56, v1r1, Pascal
constexpr
index_t
NPerBlock
=
32
;
constexpr
index_t
NPerBlock
=
32
;
...
@@ -282,13 +285,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -282,13 +285,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
...
@@ -298,7 +301,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -298,7 +301,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -324,14 +327,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -324,14 +327,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
1
;
constexpr
index_t
GemmDataPerReadA
=
1
;
constexpr
index_t
GemmDataPerReadB
=
1
;
constexpr
index_t
GemmDataPerReadB
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
4
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -347,13 +350,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -347,13 +350,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
...
@@ -365,7 +368,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -365,7 +368,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -393,12 +396,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -393,12 +396,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
2
,
4
,
4
>
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
2
,
4
,
4
>
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
#elif 0
#elif 0
// for 1x1, 28x28, v1r1, Pascal
// for 1x1, 28x28, v1r1, Pascal
constexpr
index_t
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
...
@@ -413,13 +416,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -413,13 +416,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
...
@@ -429,7 +432,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -429,7 +432,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -453,14 +456,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -453,14 +456,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
...
@@ -478,9 +481,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -478,9 +481,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif
0
#elif
0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 0
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif
#endif
<
GridSize
,
<
GridSize
,
...
@@ -507,10 +510,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -507,10 +510,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GemmKPerThreadLoop
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
,
GemmDataPerReadB
,
InBlockCopySubLengths_CHWN
,
InBlockCopyClusterLengths_CHWN
,
InBlockCopyClusterLengths_CHWN
,
InBlockCopyDataPerRead_N
,
InBlockCopyDataPerAccess_N
,
WeiBlockCopyDataPerRead_K
,
WeiBlockCopySubLengths_CK
,
OutThreadCopyDataPerWrite_N
>
{};
WeiBlockCopyClusterLengths_CK
,
WeiBlockCopyDataPerAccess_K
,
OutThreadCopyDataPerAccess_N
>
{};
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
...
...
driver/src/driver.cpp
View file @
8bdaba51
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "conv_common.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "host_conv.hpp"
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
...
@@ -85,7 +85,7 @@ int main(int argc, char* argv[])
...
@@ -85,7 +85,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
0
#elif
1
// 3x3, 34x34
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -367,7 +367,7 @@ int main(int argc, char* argv[])
...
@@ -367,7 +367,7 @@ int main(int argc, char* argv[])
#if 0
#if 0
device_convolution_direct_v2_nchw_kcyx_nkhw
device_convolution_direct_v2_nchw_kcyx_nkhw
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif
0
#elif
1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
(
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif 0
#elif 0
...
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
...
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
...
@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
nrepeat
);
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment