yangql / composable_kernel-1 · Commit 88b77181

authored Jun 11, 2019 by Chao Liu

rename files, added header guard, added namespace

Parent: 05e04665
Changes: 62 · Showing 20 changed files with 222 additions and 85 deletions (+222 −85)
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp                    +23 −15
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp  +24 −16
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp                      +9  −2
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp    +10 −3
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp                      +14 −7
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp    +15 −8
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp                      +15 −8
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp    +16 −9
src/include/gridwise_convolution_kernel_wrapper.hpp                                       +16 −0
src/include/gridwise_convolution_wrapper.hpp                                              +0  −9
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp                   +4  −0
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp                +4  −0
src/include/integral_constant.hpp                                                         +7  −1
src/include/tensor.hpp                                                                    +5  −1
src/include/threadwise_4d_tensor_op.hpp                                                   +8  −1
src/include/threadwise_direct_convolution.hpp                                             +9  −2
src/include/threadwise_gemm.hpp                                                           +8  −1
src/include/threadwise_generic_tensor_op.hpp                                              +19 −0
src/include/threadwise_generic_tensor_slice_copy.hpp                                      +8  −1
src/include/threadwise_tensor_slice_copy.hpp                                              +8  −1
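The change repeated across these headers is easiest to see in isolation. Below is a minimal sketch of the before/after pattern for a hypothetical header foo.hpp (the name and the Foo type are illustrative, not from this commit): #pragma once is replaced by a CK_-prefixed include guard, and all declarations move into namespace ck.

// Before: pragma-based guard, declarations in the global namespace.
//
//   #pragma once
//   template <class T> struct Foo {};
//
// After: classic include guard plus the ck namespace (the guard macro
// follows the CK_<FILE> convention this commit uses).

#ifndef CK_FOO_HPP
#define CK_FOO_HPP

namespace ck {

template <class T>
struct Foo
{
};

} // namespace ck
#endif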
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
-#include "blockwise_tensor_slice_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
-#include "threadwise_4d_tensor_op.hpp"
+#include "blockwise_tensor_slice_copy.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_op.hpp"
 #include "blockwise_batched_gemm.hpp"
 
+namespace ck {
+
 template <index_t GridSize,
           index_t BlockSize,
           class Float,

@@ -78,10 +82,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
 
-        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
-        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
-        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
-        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
 
         constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
             Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

@@ -103,7 +107,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         // LDS tensor view
         // be careful of alignment
-        constexpr index_t max_align = mod_conv::lcm(InBlockReorderDataPerWrite_N,
+        constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
                                                     WeiBlockCopyDataPerRead_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -119,7 +123,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, KPerBlock>{},
-            Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
 
         // tensor view of threadwise output in register
         constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(

@@ -230,7 +234,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
 #endif
 
         // set threadwise output tensor to 0
-        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
+        threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
 
 #if 0
         const Float* p_in_global_block_offset =

@@ -436,8 +440,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
                     wo_block_data_begin + wo_thread_data_begin),
                 make_zero_array<index_t, 10>(),
                 out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
-                arithmetic_sequence_gen<0, 10, 1>::SeqType{});
+                arithmetic_sequence_gen<0, 10, 1>::SeqType{},
+                Number<1>{});
 #endif
         });
     }
 };
+
+} // namespace ck
+#endif
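The mod_conv:: → math:: change above is a namespace rename at the call sites; the diff does not show the helper's definition. As a sketch of what a constexpr ceiling division like math::integer_divide_ceil(N, NPerBlock) conventionally computes (an assumption, not the repository's exact code):

#include <cassert>

using index_t = int; // assumption: ck's index_t is an integer type

// Conventional constexpr ceiling division: the smallest number of blocks
// of size 'per_block' that covers 'total' items.
constexpr index_t integer_divide_ceil(index_t total, index_t per_block)
{
    return (total + per_block - 1) / per_block;
}

int main()
{
    // e.g. N = 100, NPerBlock = 32 -> 4 blocks of work
    static_assert(integer_divide_ceil(100, 32) == 4, "");
    assert(integer_divide_ceil(64, 32) == 2);
    return 0;
}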
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hpp → src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
-#include "blockwise_tensor_slice_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
-#include "threadwise_4d_tensor_op.hpp"
+#include "blockwise_tensor_slice_copy.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_op.hpp"
 #include "blockwise_batched_gemm.hpp"
 
+namespace ck {
+
 template <index_t GridSize,
           index_t BlockSize,
           class Float,

@@ -40,7 +44,7 @@ template <index_t GridSize,
           class WeiBlockCopyClusterLengths_CK, // not used
           index_t WeiBlockCopyDataPerRead_K,
           index_t OutThreadCopyDataPerWrite_W>
-struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
+struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,

@@ -81,10 +85,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
 
-        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
-        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
-        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
-        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
 
         constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
             Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

@@ -105,7 +109,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         // LDS tensor view
         // be careful of alignment
-        constexpr index_t max_align = mod_conv::lcm(InBlockReorderDataPerWrite_N,
+        constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
                                                     WeiBlockCopyDataPerRead_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -121,7 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, KPerBlock>{},
-            Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
 
         // tensor view of threadwise output in register
         constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(

@@ -233,7 +237,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
 #endif
 
         // set threadwise output tensor to 0
-        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
+        threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
 
         for(index_t y = 0; y < Y; ++y)
         {

@@ -487,8 +491,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
                     wo_block_data_begin + wo_thread_data_begin),
                 make_zero_array<index_t, 10>(),
                 out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
-                arithmetic_sequence_gen<0, 10, 1>::SeqType{});
+                arithmetic_sequence_gen<0, 10, 1>::SeqType{},
+                Number<1>{});
 #endif
         });
     }
 };
+
+} // namespace ck
+#endif
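The _lds_double_buffer variants overlap global-memory loads with compute by ping-ponging between two LDS tiles; the kernel bodies are mostly collapsed in this diff. The kernel below is a generic, self-contained illustration of that technique (names and the reduction workload are illustrative, not this file's code): while tile i is consumed from one shared buffer, tile i+1 is staged into the other.

#include <cuda_runtime.h>

constexpr int TILE = 256; // assumed tile width for this sketch

// Illustrative double-buffered sum: each thread accumulates its lane across
// 'num_tiles' tiles, prefetching the next tile into the idle buffer while
// the current one is consumed.
__global__ void double_buffered_sum(const float* in, float* out, int num_tiles)
{
    __shared__ float buf[2][TILE];
    const int t = threadIdx.x;

    buf[0][t] = in[t]; // preload tile 0
    __syncthreads();

    float acc = 0.0f;
    for(int i = 0; i < num_tiles; ++i)
    {
        if(i + 1 < num_tiles)
            buf[(i + 1) % 2][t] = in[(i + 1) * TILE + t]; // prefetch next tile

        acc += buf[i % 2][t]; // consume current tile

        // make sure every thread is done with buf[i % 2] before iteration
        // i + 1 overwrites it, and that the prefetched tile is fully staged
        __syncthreads();
    }
    out[t] = acc;
}

A launch would look like double_buffered_sum<<<1, TILE>>>(d_in, d_out, num_tiles); the same even/odd buffer rotation, applied to input and weight tiles feeding a blockwise GEMM, is what the lds_double_buffer kernels implement.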
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"

@@ -6,6 +8,8 @@
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_gemm.hpp"
 
+namespace ck {
+
 // define B = flatten(N, Hi, Wi)
 template <index_t GridSize,
           index_t BlockSize,

@@ -181,7 +185,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
         // LDS: be careful of alignment
         constexpr index_t max_align =
-            mod_conv::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
+            math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
 
         constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});

@@ -275,3 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
         }
     }
 };
+
+} // namespace ck
+#endif
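math::lcm here picks one LDS allocation alignment that simultaneously satisfies every vectorized access width in play (e.g. a float4 read needs 4-element alignment). A minimal constexpr sketch of the arithmetic behind calls like lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead) (the helper's actual definition is not shown in the diff, so this is an assumption):

using index_t = int; // assumption: ck's index_t is an integer type

constexpr index_t gcd(index_t a, index_t b) { return b == 0 ? a : gcd(b, a % b); }
constexpr index_t lcm(index_t a, index_t b) { return a / gcd(a, b) * b; }

// Variadic form: fold lcm over all access widths.
template <class... Ts>
constexpr index_t lcm(index_t a, index_t b, Ts... rest)
{
    return lcm(lcm(a, b), rest...);
}

// e.g. widths 4, 2 and 1 -> any buffer aligned to 4 elements is safe
// for all three access patterns.
static_assert(lcm(index_t(4), index_t(2), index_t(1)) == 4, "");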
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
 
+namespace ck {
+
 // define B = flatten(N, Hi, Wi)
 template <index_t GridSize,
           index_t BlockSize,

@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
         // LDS: be careful of alignment
         constexpr index_t max_align =
-            mod_conv::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
+            math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
 
         constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});

@@ -404,3 +408,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
         }
     }
 };
+
+} // namespace ck
+#endif
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_op.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
 
+namespace ck {
+
 // define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,

@@ -146,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         // be careful of LDS alignment
         constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, KPerBlock>{},
-            Number<mod_conv::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
+            Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
 
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor

@@ -218,7 +222,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         };
 
         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDataPerAccess_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -368,3 +372,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         }
     }
 };
+
+} // namespace ck
+#endif
src/include/gridwise_convolution_implicit_gemm_v3_lds_double_buffer_nchw_cyxk_nkhw.hpp → src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_op.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
 
+namespace ck {
+
 // define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,

@@ -34,7 +38,7 @@ template <index_t GridSize,
           class WeiBlockCopySubLengths_C_K,
           class WeiBlockCopyClusterLengths_C_K,
           index_t WeiBlockCopyDataPerAccess_K>
-struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
+struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,

@@ -143,7 +147,7 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
         // be careful of LDS alignment
         constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, KPerBlock>{},
-            Number<mod_conv::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
+            Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
 
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor

@@ -215,7 +219,7 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
         };
 
         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDataPerAccess_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -395,3 +399,6 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
         }
     }
 };
+
+} // namespace ck
+#endif
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_op.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
-#include "threadwise_generic_tensor_slice_op.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
 
+namespace ck {
+
 // define B = merge(N0, Ho, Wo)
 template <index_t GridSize,

@@ -176,7 +180,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
         // be careful of LDS alignment
         constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<EPerBlock, KPerBlock>{},
-            Number<mod_conv::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
+            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
 
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor

@@ -248,7 +252,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
         };
 
         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDstDataPerWrite_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -345,3 +349,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
         }
     }
 };
+
+} // namespace ck
+#endif
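The comment "define B = merge(N0, Ho, Wo)" in these v3/v4 kernels means the GEMM's B dimension is a single index running over the merged (N0, Ho, Wo) space. Decomposing a merged index back into its components is plain mixed-radix arithmetic; a host-side sketch of the idea (illustrative, not the merged-descriptor implementation):

#include <cassert>

int main()
{
    const int N0 = 2, Ho = 4, Wo = 8; // example lengths
    const int B  = N0 * Ho * Wo;      // merged dimension length = 64

    // Recover (n0, ho, wo) from a merged index b, innermost dimension last.
    const int b  = 37;
    const int wo = b % Wo;
    const int ho = (b / Wo) % Ho;
    const int n0 = b / (Wo * Ho);

    assert(b < B);
    assert(n0 * Ho * Wo + ho * Wo + wo == b); // round-trips exactly
    return 0;
}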
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hpp → src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
 
 #include "common.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_op.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
-#include "threadwise_generic_tensor_slice_op.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
 
+namespace ck {
+
 // define B = merge(N0, Ho, Wo)
 template <index_t GridSize,

@@ -42,7 +46,7 @@ template <index_t GridSize,
           class WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K>
-struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
+struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,

@@ -165,7 +169,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // be careful of LDS alignment
         constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<EPerBlock, KPerBlock>{},
-            Number<mod_conv::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
+            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
 
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor

@@ -237,7 +241,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         };
 
         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDstDataPerWrite_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -410,3 +414,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         }
     }
 };
+
+} // namespace ck
+#endif
src/include/gridwise_convolution_kernel_wrapper.hpp (new file, 0 → 100644, View file @ 88b77181)

+#ifndef CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
+#define CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
+
+namespace ck {
+
+template <class GridwiseConvolution, class T>
+__global__ void run_gridwise_convolution_kernel(const T* const __restrict__ p_in_global,
+                                                const T* const __restrict__ p_wei_global,
+                                                T* const __restrict__ p_out_global)
+{
+    GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
+}
+
+} // namespace ck
+#endif
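The wrapper's contract is simply "any default-constructible type with a __device__ Run(const T*, const T*, T*) can be launched as a kernel". A hedged sketch of what satisfying and launching it looks like (DummyGridwiseConv and the device pointers are hypothetical stand-ins; real gridwise convolutions carry the many template parameters shown in the headers above):

// Minimal stand-in satisfying the wrapper's contract.
struct DummyGridwiseConv
{
    __device__ void Run(const float* const __restrict__ p_in,
                        const float* const __restrict__ p_wei,
                        float* const __restrict__ p_out)
    {
        p_out[threadIdx.x] = p_in[threadIdx.x] + p_wei[threadIdx.x];
    }
};

// Launch: the kernel default-constructs the type and forwards the pointers.
// ck::run_gridwise_convolution_kernel<DummyGridwiseConv, float>
//     <<<1, 64>>>(p_in_dev, p_wei_dev, p_out_dev);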
src/include/gridwise_convolution_wrapper.hpp (deleted, 100644 → 0, View file @ 05e04665)

-#pragma once
-
-template <class GridwiseConvolution, class T>
-__global__ void run_gridwise_convolution(const T* const __restrict__ p_in_global,
-                                         const T* const __restrict__ p_wei_global,
-                                         T* const __restrict__ p_out_global)
-{
-    GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
-}
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp (View file @ 88b77181)

@@ -7,6 +7,8 @@
 #include "threadwise_4d_tensor_op.hpp"
 #include "threadwise_direct_convolution.hpp"
 
+namespace ck {
+
 template <class TInWei,
           class TOut,
           class TAccum,

@@ -253,3 +255,5 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
             wo_block_data_begin + wo_thread_data_begin),
         out_nkhw_thread_desc.GetLengths());
 }
+
+} // namespace ck
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp (View file @ 88b77181)

@@ -7,6 +7,8 @@
 #include "threadwise_4d_tensor_op.hpp"
 #include "blockwise_gemm.hpp"
 
+namespace ck {
+
 template <index_t GridSize,
           index_t BlockSize,
           class Float,

@@ -292,3 +294,5 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
         out_hkwn_thread_desc.GetLengths(),
         reorder_khwn_from_hkwn);
 }
+
+} // namespace ck
src/include/integral_constant.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_INTEGRAL_CONSTANT_HPP
+#define CK_INTEGRAL_CONSTANT_HPP
+
+namespace ck {
 
 template <class T, T N>
 struct integral_constant

@@ -16,3 +19,6 @@ __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_c
 
 template <index_t N>
 using Number = integral_constant<index_t, N>;
+
+} // namespace ck
+#endif
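ck's integral_constant mirrors std::integral_constant but, per the operator+ overload visible in the hunk header above, adds compile-time arithmetic, so Number<N> values can be combined and passed around as tags carrying a value in the type. A small usage sketch under that reading of the diff (assumes index_t, defined elsewhere in the repo, is visible):

#include "integral_constant.hpp"

using namespace ck;

// Number<N> is a value carried in the type: a function can receive it as a
// tag and recover N at compile time.
template <index_t N>
constexpr index_t get_value(Number<N>)
{
    return N;
}

static_assert(get_value(Number<3>{} + Number<4>{}) == 7,
              "operator+ on integral_constant folds at compile time");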
src/include/tensor.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_TENSOR_HPP
+#define CK_TENSOR_HPP
 
 #include <thread>
 #include <vector>
 #include <numeric>

@@ -266,3 +268,5 @@ struct Tensor
     TensorDescriptor mDesc;
     std::vector<T> mData;
 };
+
+#endif
src/include/threadwise_4d_tensor_op.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
+#define CK_THREADWISE_4D_TENSOR_OP_HPP
 
 #include "ConstantTensorDescriptor.hpp"
 
+namespace ck {
+
 template <class Float, class Desc, class IDim, class NShift>
 __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
 {

@@ -50,3 +54,6 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
         }
     }
 }
+
+} // namespace ck
+#endif
src/include/threadwise_direct_convolution.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
+#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP
 
 #include "ConstantTensorDescriptor.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
 
+namespace ck {
+
 // optimized for scenario if p_in, p_wei, p_out are in register
 template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>

@@ -218,3 +222,6 @@ __device__ void threadwise_direct_convolution_3(InDesc,
 }
 #endif
 }
+
+} // namespace ck
+#endif
src/include/threadwise_gemm.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_THREADWISE_GEMM_HPP
+#define CK_THREADWISE_GEMM_HPP
 
 #include "common.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 
+namespace ck {
+
 template <class Float, class Matrix>
 __device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
 {

@@ -114,3 +118,6 @@ __device__ void threadwise_gemm(MatrixA,
         assert(false);
     }
 }
+
+} // namespace ck
+#endif
src/include/threadwise_generic_tensor_op.hpp (new file, 0 → 100644, View file @ 88b77181)

+#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
+#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP
+
+#include "ConstantTensorDescriptor.hpp"
+#include "ConstantMergedTensorDescriptor.hpp"
+
+namespace ck {
+
+template <class Float, class TDesc>
+__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
+{
+    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
+        constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
+        p[offset] = static_cast<Float>(0);
+    });
+}
+
+} // namespace ck
+#endif
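This new helper replaces the 4D-specific threadwise_4d_tensor_set_zero used by the v1r3 kernels: static_ford unrolls a compile-time loop nest over every multi-index in the descriptor's lengths, and each element is zeroed through its computed offset, so the zero-fill works for any rank. A host-side analog of the same traversal for a packed 2×3 tensor (plain runtime loops standing in for static_ford; not the library code):

#include <cassert>

int main()
{
    // Packed 2x3 "tensor": strides {3, 1}, as a packed descriptor would give.
    const int lengths[2] = {2, 3};
    const int strides[2] = {3, 1};
    float p[6] = {1, 2, 3, 4, 5, 6};

    // The loop nest static_ford generates at compile time, written at runtime:
    // visit every multi-index, map it to a linear offset, zero the element.
    for(int i0 = 0; i0 < lengths[0]; ++i0)
        for(int i1 = 0; i1 < lengths[1]; ++i1)
        {
            const int offset = i0 * strides[0] + i1 * strides[1];
            p[offset]        = 0.0f; // mirrors p[offset] = static_cast<Float>(0)
        }

    for(float v : p)
        assert(v == 0.0f);
    return 0;
}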
src/include/threadwise_generic_tensor_slice_op.hpp → src/include/threadwise_generic_tensor_slice_copy.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
+#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
 
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
 
+namespace ck {
+
 template <class Float,
           class SrcDesc,
           class DstDesc,

@@ -97,3 +101,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     });
 #endif
 }
+
+} // namespace ck
+#endif
src/include/threadwise_tensor_slice_op.hpp → src/include/threadwise_tensor_slice_copy.hpp (View file @ 88b77181)

-#pragma once
+#ifndef CK_THREADWISE_TENSOR_SLICE_COPY_HPP
+#define CK_THREADWISE_TENSOR_SLICE_COPY_HPP
 
 #include "ConstantTensorDescriptor.hpp"
 
+namespace ck {
+
 // need to assume src and dst is aligned
 template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
 __device__ void threadwise_tensor_slice_copy(SrcDesc,

@@ -192,3 +196,6 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
         });
     });
 }
+
+} // namespace ck
+#endif