yangql / composable_kernel-1 · Commit 740149fc
Authored Aug 13, 2019 by Chao Liu

    clean up

Parent: 40836ab9

Showing 13 changed files with 71 additions and 215 deletions
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp (+29 -5)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp (+20 -39)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp (+2 -4)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp (+2 -4)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp (+2 -4)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp (+2 -4)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (+2 -4)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp (+2 -4)
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp (+3 -3)
composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp (+6 -139)
composable_kernel/include/utility/config_amd.hpp.in (+0 -2)
composable_kernel/include/utility/config_nvidia.hpp.in (+0 -2)
driver/src/driver.cpp (+1 -1)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp

@@ -126,7 +126,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // blockwise copy
         // input: format is [C, Hi, Wi, N]
         auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(in_c_h_w_n_global_desc),
                                                decltype(in_c_h_w_n_block_desc),
                                                decltype(in_c_h_w_n_block_desc.GetLengths()),
@@ -142,9 +142,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                {0, 0, 0, 0});

         // blockwise wei copy
-        // format is [CPerBlock, KPerBlock]
+        // format is [CPerBlock, X * KPerBlock]
         const auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(wei_c_k_global_desc),
                                                decltype(wei_c_k_block_desc),
                                                decltype(wei_c_k_block_desc.GetLengths()),
@@ -317,7 +317,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                 wo_block_data_begin + wo_thread_data_begin,
                 n_block_data_begin + n_thread_data_begin);

-            ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
+#if 1
+            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  decltype(out_10d_thread_desc.GetLengths()),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
+            ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                   decltype(out_10d_global_desc),
                                                   decltype(out_10d_thread_desc.GetLengths()),
                                                   arithmetic_sequence_gen<0, 10, 1>::type,
@@ -328,6 +339,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                   OutThreadCopyDataPerAccess_N>(
                 make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
                 .Run(p_out_thread, p_out_thread_on_global);
+#endif
         }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
@@ -375,7 +387,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                 wo_block_data_begin + wo_thread_data_begin,
                 n_block_data_begin + n_thread_data_begin);

-            ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
+#if 1
+            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  decltype(out_10d_thread_desc.GetLengths()),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
+            ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                   decltype(out_10d_global_desc),
                                                   decltype(out_10d_thread_desc.GetLengths()),
                                                   arithmetic_sequence_gen<0, 10, 1>::type,
@@ -386,6 +409,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                   OutThreadCopyDataPerAccess_N>(
                 make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
                 .Run(p_out_thread, p_out_thread_on_global);
+#endif
         });
     }
 };

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp

@@ -7,10 +7,6 @@
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_batched_gemm.hpp"
-#include "blockwise_2d_tensor_op.hpp"
-#include "blockwise_4d_tensor_op.hpp"
-#include "threadwise_tensor_slice_copy.hpp"
-#include "threadwise_4d_tensor_op.hpp"

 namespace ck {
@@ -133,18 +129,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
             Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

-#if 1
         // blockwise copy
         // input: format is [C, Hi, Wi, N]
-        const auto blockwise_in_copy =
-            Blockwise4dTensorCopy3<BlockSize,
-                                   Float,
-                                   decltype(in_c_h_w_n_global_desc),
-                                   decltype(in_c_h_w_n_block_desc),
-                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
-                                   InBlockCopyClusterLengths_CHWN,
-                                   InBlockCopyDataPerAccess_N>{};
-#else
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(in_c_h_w_n_global_desc),
@@ -160,19 +146,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                                                InBlockCopyDataPerAccess_N,
                                                InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
                                                                            {0, 0, 0, 0});
-#endif

-#if 1
         // blockwise wei copy
         // format is [CPerBlock, X * KPerBlock]
-        const auto blockwise_wei_copy =
-            Blockwise2dTensorCopy3<BlockSize,
-                                   Float,
-                                   decltype(wei_c_k_global_desc),
-                                   decltype(wei_c_k_block_desc),
-                                   decltype(wei_c_k_block_desc.GetLengths()),
-                                   WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
-#else
         const auto blockwise_wei_copy =
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                decltype(wei_c_k_global_desc),
@@ -187,7 +163,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                                                1,
                                                WeiBlockCopyDataPerAccess_K,
                                                WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
-#endif

         // a series of blockwise batched GEMM
         // C_matrix += transpose(A_matrix) * B_matrix
@@ -428,13 +403,16 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                 n_block_data_begin + n_thread_data_begin);

 #if 1
-            threadwise_tensor_slice_copy(out_10d_thread_desc,
-                                         p_out_thread,
-                                         out_10d_global_desc,
-                                         p_out_thread_on_global,
-                                         out_10d_thread_desc.GetLengths(),
-                                         Number<OutThreadCopyDataPerAccess_N>{});
-#else
+            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  decltype(out_10d_thread_desc.GetLengths()),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
             ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                   decltype(out_10d_global_desc),
                                                   decltype(out_10d_thread_desc.GetLengths()),
@@ -495,13 +473,16 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
                 n_block_data_begin + n_thread_data_begin);

 #if 1
-            threadwise_tensor_slice_copy(out_10d_thread_desc,
-                                         p_out_thread,
-                                         out_10d_global_desc,
-                                         p_out_thread_on_global,
-                                         out_10d_thread_desc.GetLengths(),
-                                         Number<OutThreadCopyDataPerAccess_N>{});
-#else
+            ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+                                                  decltype(out_10d_global_desc),
+                                                  decltype(out_10d_thread_desc.GetLengths()),
+                                                  arithmetic_sequence_gen<0, 10, 1>::type,
+                                                  9,
+                                                  OutThreadCopyDataPerAccess_N,
+                                                  OutThreadCopyDataPerAccess_N>(
+                make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+                .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
             ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
                                                   decltype(out_10d_global_desc),
                                                   decltype(out_10d_thread_desc.GetLengths()),

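The unchanged comment in this file, "// C_matrix += transpose(A_matrix) * B_matrix", states what the blockwise batched GEMM accumulates. Below is a minimal host-side C++ sketch of that accumulation with made-up toy sizes and plain row-major layouts; it is an illustration only, not the kernel's LDS/register data layout or the repository's API.

// Toy illustration: plain loops for C += transpose(A) * B, the operation the
// blockwise batched GEMM distributes across a thread block. A is stored K x M
// (so it is read transposed), B is K x N, and C is M x N, all row-major.
#include <cstdio>
#include <vector>

void gemm_at_b_accumulate(const std::vector<float>& a, // K x M
                          const std::vector<float>& b, // K x N
                          std::vector<float>& c,       // M x N, accumulated into
                          int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}

int main()
{
    const int M = 2, N = 3, K = 4;
    std::vector<float> a(K * M, 1.0f), b(K * N, 2.0f), c(M * N, 0.0f);
    gemm_at_b_accumulate(a, b, c, M, N, K);
    std::printf("c[0][0] = %f\n", c[0]); // 4 * (1 * 2) = 8
    return 0;
}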
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp

@@ -234,12 +234,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         constexpr auto b_e_n1bn2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n1_b_n2_block_desc.Unfold(I1, I3));
+            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -247,12 +247,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         constexpr auto b_e_n1bn2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n1_b_n2_block_desc.Unfold(I1, I3));
+            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -222,8 +222,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with multiple alignment
@@ -233,8 +232,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
                       "GemmDataPerReadB alignment requirement is not satisfied");

         constexpr auto b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
+            make_ConstantMatrixDescriptor(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -228,8 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with multiple alignment
@@ -239,8 +238,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
                       "GemmDataPerReadB alignment requirement is not satisfied");

         constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
-                in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
+            make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -172,11 +172,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // b_mtx[EPerBlocl, BPerBlock] is in LDS
         // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

-        constexpr auto b_e_b_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc);
+        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

         // sanity check
         static_assert(

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -182,11 +182,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // b_mtx[EPerBlocl, BPerBlock] is in LDS
         // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);
+        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

-        constexpr auto b_e_b_block_mtx_desc =
-            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(in_e_b_block_desc);
+        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

         // sanity check
         static_assert(

composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp

@@ -52,10 +52,10 @@ __host__ __device__ constexpr auto
     return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
 }

-template <class TDesc>
+template <class... Ts>
 __host__ __device__ constexpr auto
-make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(TDesc)
+make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
 {
+    using TDesc = ConstantTensorDescriptor<Ts...>;
     static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
     static_assert(TDesc::GetStrides()[1] == 1, "wrong");
     return ConstantMatrixDescriptor<TDesc::GetLengths()[0],

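The rename above is possible because the new factory overloads make_ConstantMatrixDescriptor on the argument's class template (deducing the descriptor's parameter pack) instead of carrying the distinction in the function name, so only ConstantTensorDescriptor arguments can select this overload. A standalone sketch of that overload-by-pack-deduction pattern with toy types; the names below are assumptions for illustration, not the repository's classes.

// Standalone sketch (not the repository's code): constrain a factory to one
// class template by taking that template as the parameter and deducing its pack.
#include <cstdio>

template <int... Ls>
struct ToyTensorDescriptor
{
    static constexpr int GetNumOfDimension() { return sizeof...(Ls); }
};

template <int NRow, int NCol>
struct ToyMatrixDescriptor
{
    static constexpr int NRowValue = NRow;
    static constexpr int NColValue = NCol;
};

// Only ToyTensorDescriptor<...> arguments match this overload; anything else is
// rejected at overload resolution, so no "_from_TensorDescriptor" suffix is needed.
template <int L0, int L1, int... Rest>
constexpr auto make_ToyMatrixDescriptor(ToyTensorDescriptor<L0, L1, Rest...>)
{
    using TDesc = ToyTensorDescriptor<L0, L1, Rest...>;
    static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
    return ToyMatrixDescriptor<L0, L1>{};
}

int main()
{
    constexpr auto mtx = make_ToyMatrixDescriptor(ToyTensorDescriptor<16, 8>{});
    std::printf("%d x %d\n", decltype(mtx)::NRowValue, decltype(mtx)::NColValue); // 16 x 8
    return 0;
}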
composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp

@@ -5,6 +5,10 @@
 #include "ConstantMatrixDescriptor.hpp"
 #include "threadwise_gemm.hpp"

+#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
+#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
+#endif
+
 namespace ck {

 template <index_t BlockSize,
@@ -97,24 +101,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                            b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);

-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
-            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
-            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
-
-            printf("%u %u, %u %u %u, %u %u\n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch,
-                   c_thread_mtx_index.row,
-                   c_thread_mtx_index.col,
-                   mMyThreadOffsetA,
-                   mMyThreadOffsetB);
-        }
-#endif
     }

     __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
@@ -257,29 +243,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             }
         }

-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
-                   p_a_thread[0],
-                   p_a_thread[1],
-                   p_a_thread[2],
-                   p_a_thread[3],
-                   p_a_thread[4],
-                   p_a_thread[5],
-                   p_a_thread[6],
-                   p_a_thread[7],
-                   p_b_thread[0],
-                   p_b_thread[1],
-                   p_b_thread[2],
-                   p_b_thread[3],
-                   p_b_thread[4],
-                   p_b_thread[5],
-                   p_b_thread[6],
-                   p_b_thread[7]);
-        }
-#endif
         threadwise_gemm(a_thread_mtx,
                         True,
                         p_a_thread,
@@ -311,10 +274,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         // thread A, B for GEMM
         // A is transposed, b is not
         constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

         constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

         // thread A-sub, B-sub for copy
         constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -382,102 +345,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
         outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
     }

-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
-                               const FloatB* __restrict__ p_b_block,
-                               FloatC* __restrict__ p_c_thread) const
-    {
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t M = a_block_mtx.NCol();
-        constexpr index_t N = b_block_mtx.NCol();
-        constexpr index_t K = a_block_mtx.NRow(); // A is transposed
-
-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-        // thread A, B for GEMM
-        // A is transposed, b is not
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        // thread A-sub, B-sub for copy
-        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
-        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
-
-        // assertion for inline asm
-        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
-                          is_same<FloatC, float>{},
-                      "Run_amd_asm only deal with float\n");
-        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
-                          MPerThread == 8 && NPerThread == 8,
-                      "Run_amd_asm cannot deal with this GEMM shape yet\n");
-        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read\n");
-        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
-                      "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
-                      "1 for now\n");
-
-        using Float4 = vector_type<float, 4>::MemoryType;
-
-        Float4* reg_a = (Float4*)(p_a_thread);
-        Float4* reg_b = (Float4*)(p_b_thread);
-        Float4* reg_c = (Float4*)(p_c_thread);
-
-        void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
-        void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
-
-        constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
-        constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();
-
-        constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
-        constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;
-
-        ds_read_b128(reg_a[0], a_lds_loc, 0);
-        ds_read_b128(reg_b[0], b_lds_loc, 0);
-        ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
-        ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
-
-        lgkmcnt(2);
-        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-
-        lgkmcnt(1);
-        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-
-#pragma unroll
-        for(index_t k = 1; k < K; ++k)
-        {
-            ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
-
-            lgkmcnt(1);
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-
-            ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-
-            ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
-            ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
-
-            lgkmcnt(2);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-
-            lgkmcnt(1);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-        }
-
-        lgkmcnt(0);
-        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-    }
 #endif

     template <class FloatA, class FloatB, class FloatC>

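The CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM block added above is the usual default-with-override guard: the header supplies 1 unless the build system or the including translation unit has already defined the macro. A minimal standalone sketch of the same idiom, using a hypothetical macro name rather than the repository's:

// Illustration only: default a compile-time switch to 1 unless the includer
// (or a -D flag on the compiler command line) defined it first.
#ifndef TOY_USE_FAST_PATH
#define TOY_USE_FAST_PATH 1
#endif

#include <cstdio>

int main()
{
#if TOY_USE_FAST_PATH
    std::printf("fast path enabled (default)\n");
#else
    std::printf("fast path disabled, e.g. via -DTOY_USE_FAST_PATH=0\n");
#endif
    return 0;
}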
composable_kernel/include/utility/config_amd.hpp.in

@@ -7,10 +7,8 @@
 #define CK_DEVICE_BACKEND_AMD 1
 #define CK_USE_AMD_INLINE_ASM 1
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

 namespace ck {

composable_kernel/include/utility/config_nvidia.hpp.in

@@ -9,10 +9,8 @@
 #define CK_DEVICE_BACKEND_NVIDIA 1
 #define CK_USE_AMD_INLINE_ASM 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

 namespace ck {

driver/src/driver.cpp

@@ -71,7 +71,7 @@ int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 0
+#if 1
     constexpr index_t N = 64;
     constexpr index_t C = 1536;
     constexpr index_t HI = 8;
