Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
88b77181
Commit
88b77181
authored
Jun 11, 2019
by
Chao Liu
Browse files
rename files, added header guard, added namespace
parent
05e04665
Changes
62
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
222 additions
and
85 deletions
+222
-85
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp
...ridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp
+23
-15
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+24
-16
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
.../gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
+9
-2
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
...ion_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
+10
-3
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
.../gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
+14
-7
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+15
-8
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
.../gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+15
-8
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+16
-9
src/include/gridwise_convolution_kernel_wrapper.hpp
src/include/gridwise_convolution_kernel_wrapper.hpp
+16
-0
src/include/gridwise_convolution_wrapper.hpp
src/include/gridwise_convolution_wrapper.hpp
+0
-9
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
...idwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
+4
-0
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
...ise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
+4
-0
src/include/integral_constant.hpp
src/include/integral_constant.hpp
+7
-1
src/include/tensor.hpp
src/include/tensor.hpp
+5
-1
src/include/threadwise_4d_tensor_op.hpp
src/include/threadwise_4d_tensor_op.hpp
+8
-1
src/include/threadwise_direct_convolution.hpp
src/include/threadwise_direct_convolution.hpp
+9
-2
src/include/threadwise_gemm.hpp
src/include/threadwise_gemm.hpp
+8
-1
src/include/threadwise_generic_tensor_op.hpp
src/include/threadwise_generic_tensor_op.hpp
+19
-0
src/include/threadwise_generic_tensor_slice_copy.hpp
src/include/threadwise_generic_tensor_slice_copy.hpp
+8
-1
src/include/threadwise_tensor_slice_copy.hpp
src/include/threadwise_tensor_slice_copy.hpp
+8
-1
No files found.
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_op.hpp"
#include "blockwise_tensor_slice_
c
op
y
.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_
c
op
y
.hpp"
#include "threadwise_
4d
_tensor_op.hpp"
#include "threadwise_
generic
_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
#include "blockwise_batched_gemm.hpp"
namespace
ck
{
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
class
Float
,
class
Float
,
...
@@ -78,10 +82,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -78,10 +82,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
NBlockWork
=
m
od_conv
::
integer_divide_ceil
(
N
,
NPerBlock
);
constexpr
index_t
NBlockWork
=
m
ath
::
integer_divide_ceil
(
N
,
NPerBlock
);
constexpr
index_t
KBlockWork
=
m
od_conv
::
integer_divide_ceil
(
K
,
KPerBlock
);
constexpr
index_t
KBlockWork
=
m
ath
::
integer_divide_ceil
(
K
,
KPerBlock
);
constexpr
index_t
HBlockWork
=
m
od_conv
::
integer_divide_ceil
(
Ho
,
HoPerBlock
);
constexpr
index_t
HBlockWork
=
m
ath
::
integer_divide_ceil
(
Ho
,
HoPerBlock
);
constexpr
index_t
WBlockWork
=
m
od_conv
::
integer_divide_ceil
(
Wo
,
WoPerBlock
);
constexpr
index_t
WBlockWork
=
m
ath
::
integer_divide_ceil
(
Wo
,
WoPerBlock
);
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor_packed
(
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
NBlockWork
,
KBlockWork
,
HBlockWork
,
WBlockWork
>
{});
Sequence
<
NBlockWork
,
KBlockWork
,
HBlockWork
,
WBlockWork
>
{});
...
@@ -103,10 +107,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -103,10 +107,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
InBlockReorderDataPerWrite_N
,
constexpr
index_t
max_align
=
m
ath
::
lcm
(
InBlockReorderDataPerWrite_N
,
WeiBlockCopyDataPerRead_K
,
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerBlock
>
{},
Sequence
<
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerBlock
>
{},
...
@@ -119,7 +123,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -119,7 +123,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Number
<
m
od_conv
::
lcm
(
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
)
>
{});
Number
<
m
ath
::
lcm
(
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
)
>
{});
// tensor view of threadwise output in register
// tensor view of threadwise output in register
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor_packed
(
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor_packed
(
...
@@ -230,7 +234,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -230,7 +234,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#endif
#endif
// set threadwise output tensor to 0
// set threadwise output tensor to 0
threadwise_
4d
_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
threadwise_
generic
_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
#if 0
#if 0
const Float* p_in_global_block_offset =
const Float* p_in_global_block_offset =
...
@@ -436,8 +440,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -436,8 +440,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
(),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqType
{});
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqType
{},
Number
<
1
>
{});
#endif
#endif
});
});
}
}
};
};
}
// namespace ck
#endif
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer
_nchw_cyxk_nkhw
.hpp
→
src/include/gridwise_convolution_implicit_gemm_v1r3_
nchw_cyxk_nkhw_
lds_double_buffer.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_op.hpp"
#include "blockwise_tensor_slice_
c
op
y
.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_
c
op
y
.hpp"
#include "threadwise_
4d
_tensor_op.hpp"
#include "threadwise_
generic
_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
#include "blockwise_batched_gemm.hpp"
namespace
ck
{
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
class
Float
,
class
Float
,
...
@@ -40,7 +44,7 @@ template <index_t GridSize,
...
@@ -40,7 +44,7 @@ template <index_t GridSize,
class
WeiBlockCopyClusterLengths_CK
,
// not used
class
WeiBlockCopyClusterLengths_CK
,
// not used
index_t
WeiBlockCopyDataPerRead_K
,
index_t
WeiBlockCopyDataPerRead_K
,
index_t
OutThreadCopyDataPerWrite_W
>
index_t
OutThreadCopyDataPerWrite_W
>
struct
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer
_nchw_cyxk_nkhw
struct
GridwiseConvolutionImplicitGemm_v1r3_
nchw_cyxk_nkhw_
lds_double_buffer
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
...
@@ -81,10 +85,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -81,10 +85,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
NBlockWork
=
m
od_conv
::
integer_divide_ceil
(
N
,
NPerBlock
);
constexpr
index_t
NBlockWork
=
m
ath
::
integer_divide_ceil
(
N
,
NPerBlock
);
constexpr
index_t
KBlockWork
=
m
od_conv
::
integer_divide_ceil
(
K
,
KPerBlock
);
constexpr
index_t
KBlockWork
=
m
ath
::
integer_divide_ceil
(
K
,
KPerBlock
);
constexpr
index_t
HBlockWork
=
m
od_conv
::
integer_divide_ceil
(
Ho
,
HoPerBlock
);
constexpr
index_t
HBlockWork
=
m
ath
::
integer_divide_ceil
(
Ho
,
HoPerBlock
);
constexpr
index_t
WBlockWork
=
m
od_conv
::
integer_divide_ceil
(
Wo
,
WoPerBlock
);
constexpr
index_t
WBlockWork
=
m
ath
::
integer_divide_ceil
(
Wo
,
WoPerBlock
);
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor_packed
(
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
NBlockWork
,
KBlockWork
,
HBlockWork
,
WBlockWork
>
{});
Sequence
<
NBlockWork
,
KBlockWork
,
HBlockWork
,
WBlockWork
>
{});
...
@@ -105,10 +109,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -105,10 +109,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
InBlockReorderDataPerWrite_N
,
constexpr
index_t
max_align
=
m
ath
::
lcm
(
InBlockReorderDataPerWrite_N
,
WeiBlockCopyDataPerRead_K
,
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerBlock
>
{},
Sequence
<
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerBlock
>
{},
...
@@ -121,7 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -121,7 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Number
<
m
od_conv
::
lcm
(
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
)
>
{});
Number
<
m
ath
::
lcm
(
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
)
>
{});
// tensor view of threadwise output in register
// tensor view of threadwise output in register
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor_packed
(
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor_packed
(
...
@@ -233,7 +237,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -233,7 +237,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
#endif
#endif
// set threadwise output tensor to 0
// set threadwise output tensor to 0
threadwise_
4d
_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
threadwise_
generic
_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
{
...
@@ -487,8 +491,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -487,8 +491,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
(),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqType
{});
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqType
{},
Number
<
1
>
{});
#endif
#endif
});
});
}
}
};
};
}
// namespace
#endif
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
...
@@ -6,6 +8,8 @@
...
@@ -6,6 +8,8 @@
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
// define B = flatten(N, Hi, Wi)
// define B = flatten(N, Hi, Wi)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -181,7 +185,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
...
@@ -181,7 +185,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
// LDS: be careful of alignment
// LDS: be careful of alignment
constexpr
index_t
max_align
=
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
m
ath
::
lcm
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
index_t
in_block_space
=
in_cb_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
in_block_space
=
in_cb_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
@@ -275,3 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
...
@@ -275,3 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
}
}
}
}
};
};
}
// namespace ck
#endif
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_
c
op
y
.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
// define B = flatten(N, Hi, Wi)
// define B = flatten(N, Hi, Wi)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
// LDS: be careful of alignment
// LDS: be careful of alignment
constexpr
index_t
max_align
=
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
m
ath
::
lcm
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
index_t
in_block_space
=
in_cb_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
in_block_space
=
in_cb_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
@@ -404,3 +408,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -404,3 +408,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
}
}
};
};
}
// namespace ck
#endif
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_op.hpp"
#include "blockwise_generic_tensor_slice_
c
op
y
.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
// define B = merge(N0, Ho, Wo)
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -146,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -146,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Number
<
m
od_conv
::
lcm
(
WeiBlockCopyDataPerAccess_K
,
GemmDataPerReadA
)
>
{});
Number
<
m
ath
::
lcm
(
WeiBlockCopyDataPerAccess_K
,
GemmDataPerReadA
)
>
{});
// operator for blockwise copy of weight into LDS
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// slice a tensor, and copy it into another tensor
...
@@ -218,10 +222,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -218,10 +222,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
};
};
// LDS allocation for input and weight: be careful of alignment
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
constexpr
index_t
max_align
=
m
ath
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDataPerAccess_K
,
WeiBlockCopyDataPerAccess_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
constexpr
index_t
in_block_space
=
in_c_n1_b_n2_block_mem_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
in_c_n1_b_n2_block_mem_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
@@ -368,3 +372,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -368,3 +372,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
}
}
}
}
};
};
}
// namespace ck
#endif
src/include/gridwise_convolution_implicit_gemm_v3_lds_double_buffer
_nchw_cyxk_nkhw
.hpp
→
src/include/gridwise_convolution_implicit_gemm_v3_
nchw_cyxk_nkhw_
lds_double_buffer.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_op.hpp"
#include "blockwise_generic_tensor_slice_
c
op
y
.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
// define B = merge(N0, Ho, Wo)
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -34,7 +38,7 @@ template <index_t GridSize,
...
@@ -34,7 +38,7 @@ template <index_t GridSize,
class
WeiBlockCopySubLengths_C_K
,
class
WeiBlockCopySubLengths_C_K
,
class
WeiBlockCopyClusterLengths_C_K
,
class
WeiBlockCopyClusterLengths_C_K
,
index_t
WeiBlockCopyDataPerAccess_K
>
index_t
WeiBlockCopyDataPerAccess_K
>
struct
GridwiseConvolutionImplicitGemm_v3_lds_double_buffer
_nchw_cyxk_nkhw
struct
GridwiseConvolutionImplicitGemm_v3_
nchw_cyxk_nkhw_
lds_double_buffer
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
...
@@ -143,7 +147,7 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -143,7 +147,7 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Number
<
m
od_conv
::
lcm
(
WeiBlockCopyDataPerAccess_K
,
GemmDataPerReadA
)
>
{});
Number
<
m
ath
::
lcm
(
WeiBlockCopyDataPerAccess_K
,
GemmDataPerReadA
)
>
{});
// operator for blockwise copy of weight into LDS
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// slice a tensor, and copy it into another tensor
...
@@ -215,10 +219,10 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -215,10 +219,10 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
};
};
// LDS allocation for input and weight: be careful of alignment
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
constexpr
index_t
max_align
=
m
ath
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDataPerAccess_K
,
WeiBlockCopyDataPerAccess_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
constexpr
index_t
in_block_space
=
in_c_n1_b_n2_block_mem_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
in_c_n1_b_n2_block_mem_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
@@ -395,3 +399,6 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -395,3 +399,6 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
}
}
}
}
};
};
}
// namesspace ck
#endif
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_op.hpp"
#include "blockwise_generic_tensor_slice_
c
op
y
.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_op.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace
ck
{
// define B = merge(N0, Ho, Wo)
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
...
@@ -176,7 +180,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
...
@@ -176,7 +180,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
wei_e_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_e_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Number
<
m
od_conv
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
Number
<
m
ath
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
// operator for blockwise copy of weight into LDS
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// slice a tensor, and copy it into another tensor
...
@@ -248,10 +252,10 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
...
@@ -248,10 +252,10 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
};
};
// LDS allocation for input and weight: be careful of alignment
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
constexpr
index_t
max_align
=
m
ath
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDstDataPerWrite_K
,
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
constexpr
index_t
in_block_space
=
in_e_n1_b_n2_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
in_e_n1_b_n2_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
@@ -345,3 +349,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
...
@@ -345,3 +349,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
}
}
}
}
};
};
}
// namespace ck
#endif
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer
_nchw_kcyx_nkhw
.hpp
→
src/include/gridwise_convolution_implicit_gemm_v4_
nchw_kcyx_nkhw_
lds_double_buffer.hpp
View file @
88b77181
#pragma once
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common.hpp"
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_op.hpp"
#include "blockwise_generic_tensor_slice_
c
op
y
.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_op.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace
ck
{
// define B = merge(N0, Ho, Wo)
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
...
@@ -42,7 +46,7 @@ template <index_t GridSize,
...
@@ -42,7 +46,7 @@ template <index_t GridSize,
class
WeiBlockCopyDstAccessOrder
,
class
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
>
index_t
WeiBlockCopyDstDataPerWrite_K
>
struct
GridwiseConvolutionImplicitGemm_v4_lds_double_buffer
_nchw_kcyx_nkhw
struct
GridwiseConvolutionImplicitGemm_v4_
nchw_kcyx_nkhw_
lds_double_buffer
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
...
@@ -165,7 +169,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -165,7 +169,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
wei_e_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_e_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Number
<
m
od_conv
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
Number
<
m
ath
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
// operator for blockwise copy of weight into LDS
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// slice a tensor, and copy it into another tensor
...
@@ -237,10 +241,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -237,10 +241,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
};
};
// LDS allocation for input and weight: be careful of alignment
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
m
od_conv
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
constexpr
index_t
max_align
=
m
ath
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDstDataPerWrite_K
,
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
constexpr
index_t
in_block_space
=
in_e_n1_b_n2_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
in_e_n1_b_n2_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
@@ -410,3 +414,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -410,3 +414,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
}
}
}
}
};
};
}
// namespace ck
#endif
src/include/gridwise_convolution_kernel_wrapper.hpp
0 → 100644
View file @
88b77181
#ifndef CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
#define CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
namespace
ck
{
template
<
class
GridwiseConvolution
,
class
T
>
__global__
void
run_gridwise_convolution_kernel
(
const
T
*
const
__restrict__
p_in_global
,
const
T
*
const
__restrict__
p_wei_global
,
T
*
const
__restrict__
p_out_global
)
{
GridwiseConvolution
{}.
Run
(
p_in_global
,
p_wei_global
,
p_out_global
);
}
}
// namespace ck
#endif
src/include/gridwise_convolution_wrapper.hpp
deleted
100644 → 0
View file @
05e04665
#pragma once
template
<
class
GridwiseConvolution
,
class
T
>
__global__
void
run_gridwise_convolution
(
const
T
*
const
__restrict__
p_in_global
,
const
T
*
const
__restrict__
p_wei_global
,
T
*
const
__restrict__
p_out_global
)
{
GridwiseConvolution
{}.
Run
(
p_in_global
,
p_wei_global
,
p_out_global
);
}
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
View file @
88b77181
...
@@ -7,6 +7,8 @@
...
@@ -7,6 +7,8 @@
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_direct_convolution.hpp"
#include "threadwise_direct_convolution.hpp"
namespace
ck
{
template
<
class
TInWei
,
template
<
class
TInWei
,
class
TOut
,
class
TOut
,
class
TAccum
,
class
TAccum
,
...
@@ -253,3 +255,5 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -253,3 +255,5 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
out_nkhw_thread_desc
.
GetLengths
());
out_nkhw_thread_desc
.
GetLengths
());
}
}
}
// namespace ck
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
View file @
88b77181
...
@@ -7,6 +7,8 @@
...
@@ -7,6 +7,8 @@
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
class
Float
,
class
Float
,
...
@@ -292,3 +294,5 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
...
@@ -292,3 +294,5 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
out_hkwn_thread_desc
.
GetLengths
(),
out_hkwn_thread_desc
.
GetLengths
(),
reorder_khwn_from_hkwn
);
reorder_khwn_from_hkwn
);
}
}
}
// namespace ck
src/include/integral_constant.hpp
View file @
88b77181
#pragma once
#ifndef CK_INTEGRAL_CONSTANT_HPP
#define CK_INTEGRAL_CONSTANT_HPP
namespace
ck
{
template
<
class
T
,
T
N
>
template
<
class
T
,
T
N
>
struct
integral_constant
struct
integral_constant
...
@@ -16,3 +19,6 @@ __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_c
...
@@ -16,3 +19,6 @@ __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_c
template
<
index_t
N
>
template
<
index_t
N
>
using
Number
=
integral_constant
<
index_t
,
N
>
;
using
Number
=
integral_constant
<
index_t
,
N
>
;
}
// namespace ck
#endif
src/include/tensor.hpp
View file @
88b77181
#pragma once
#ifndef CK_TENSOR_HPP
#define CK_TENSOR_HPP
#include <thread>
#include <thread>
#include <vector>
#include <vector>
#include <numeric>
#include <numeric>
...
@@ -266,3 +268,5 @@ struct Tensor
...
@@ -266,3 +268,5 @@ struct Tensor
TensorDescriptor
mDesc
;
TensorDescriptor
mDesc
;
std
::
vector
<
T
>
mData
;
std
::
vector
<
T
>
mData
;
};
};
#endif
src/include/threadwise_4d_tensor_op.hpp
View file @
88b77181
#pragma once
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
#define CK_THREADWISE_4D_TENSOR_OP_HPP
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace
ck
{
template
<
class
Float
,
class
Desc
,
class
IDim
,
class
NShift
>
template
<
class
Float
,
class
Desc
,
class
IDim
,
class
NShift
>
__device__
void
threadwise_4d_tensor_shift_down
(
Desc
,
Float
*
__restrict__
p
,
IDim
,
NShift
)
__device__
void
threadwise_4d_tensor_shift_down
(
Desc
,
Float
*
__restrict__
p
,
IDim
,
NShift
)
{
{
...
@@ -50,3 +54,6 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
...
@@ -50,3 +54,6 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
}
}
}
}
}
}
}
// namespace ck
#endif
src/include/threadwise_direct_convolution.hpp
View file @
88b77181
#pragma once
#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace
ck
{
// optimized for scenario if p_in, p_wei, p_out are in register
// optimized for scenario if p_in, p_wei, p_out are in register
template
<
class
TInWei
,
class
TOut
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
TInWei
,
class
TOut
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
...
@@ -218,3 +222,6 @@ __device__ void threadwise_direct_convolution_3(InDesc,
...
@@ -218,3 +222,6 @@ __device__ void threadwise_direct_convolution_3(InDesc,
}
}
#endif
#endif
}
}
}
// namespace ck
#endif
src/include/threadwise_gemm.hpp
View file @
88b77181
#pragma once
#ifndef CK_THREADWISE_GEMM_HPP
#define CK_THREADWISE_GEMM_HPP
#include "common.hpp"
#include "common.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
namespace
ck
{
template
<
class
Float
,
class
Matrix
>
template
<
class
Float
,
class
Matrix
>
__device__
void
threadwise_matrix_set_zero
(
Matrix
,
Float
*
__restrict__
p_thread
)
__device__
void
threadwise_matrix_set_zero
(
Matrix
,
Float
*
__restrict__
p_thread
)
{
{
...
@@ -114,3 +118,6 @@ __device__ void threadwise_gemm(MatrixA,
...
@@ -114,3 +118,6 @@ __device__ void threadwise_gemm(MatrixA,
assert
(
false
);
assert
(
false
);
}
}
}
}
}
// namespace ck
#endif
src/include/threadwise_generic_tensor_op.hpp
0 → 100644
View file @
88b77181
#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP

#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"

namespace ck {

// Zero-fill the thread-private tensor described by TDesc.
// static_ford unrolls a compile-time loop over every multi-index of the
// descriptor's lengths; for each index the linear offset is computed at
// compile time and Float(0) is stored at that position in p.
template <class Float, class TDesc>
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
{
    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
        // offset is a compile-time constant for each unrolled iteration
        p[TDesc::GetOffsetFromMultiIndex(multi_id)] = static_cast<Float>(0);
    });
}

} // namespace ck
#endif
src/include/threadwise_generic_tensor_slice_op.hpp
→
src/include/threadwise_generic_tensor_slice_
c
op
y
.hpp
View file @
88b77181
#pragma once
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
namespace
ck
{
template
<
class
Float
,
template
<
class
Float
,
class
SrcDesc
,
class
SrcDesc
,
class
DstDesc
,
class
DstDesc
,
...
@@ -97,3 +101,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
...
@@ -97,3 +101,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
});
});
#endif
#endif
}
}
}
// namespace ck
#endif
src/include/threadwise_tensor_slice_op.hpp
→
src/include/threadwise_tensor_slice_
c
op
y
.hpp
View file @
88b77181
#pragma once
#ifndef CK_THREADWISE_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_TENSOR_SLICE_COPY_HPP
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace
ck
{
// need to assume src and dst is aligned
// need to assume src and dst is aligned
template
<
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
,
index_t
DataPerRead
>
template
<
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
,
index_t
DataPerRead
>
__device__
void
threadwise_tensor_slice_copy
(
SrcDesc
,
__device__
void
threadwise_tensor_slice_copy
(
SrcDesc
,
...
@@ -192,3 +196,6 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
...
@@ -192,3 +196,6 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
});
});
});
});
}
}
}
// namespace ck
#endif
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment