gaoqiong / composable_kernel / Commits / b37cb71f

Commit b37cb71f, authored Oct 16, 2019 by Wen-Heng (Jack) Chung
parent c5143bca

    Enable bwd wrw

Changes: 26. Showing 20 changed files with 3954 additions and 805 deletions (+3954, -805).
composable_kernel/include/kernel_algorithm/convolution_common.hpp  (+14, -0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+66, -68)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw_lds_double_buffer.hpp  (+51, -54)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+496, -0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp  (+447, -0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+29, -36)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer.hpp  (+49, -55)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+55, -69)
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp  (+23, -5)
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor_deprecated.hpp  (+210, -0)
composable_kernel/include/tensor_description/ConstantTensorDescriptor_deprecated.hpp  (+612, -0)
composable_kernel/include/tensor_description/dimension.hpp  (+17, -0)
composable_kernel/include/tensor_description/multi_index_transform.hpp  (+421, -0)
composable_kernel/include/tensor_description/tensor_coordinate.hpp  (+166, -217)
composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp  (+348, -0)
composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp  (+16, -0)
composable_kernel/include/tensor_description/tensor_descriptor.hpp  (+500, -0)
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp  (+212, -0)
composable_kernel/include/tensor_operation/blockwise_gemm.hpp  (+204, -298)
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  (+18, -3)
composable_kernel/include/kernel_algorithm/convolution_common.hpp  (new file, mode 100644)

#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP

namespace ck {

enum ConvolutionDirection
{
    Forward,
    BackwardData,
    BackwardWeight
};

} // namespace ck
#endif
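The rest of the commit dispatches per-direction behaviour by specializing templates on this enum (make_vectorized_WeiDesc and make_wei_e_k_global_desc_v4r1 below). A minimal, self-contained sketch of that dispatch pattern; the direction_name trait is purely illustrative and not part of the repository:

#include <cstdio>

namespace ck {
enum ConvolutionDirection
{
    Forward,
    BackwardData,
    BackwardWeight
};
} // namespace ck

// Hypothetical trait, specialized per direction the same way the gridwise
// kernels below specialize their weight-descriptor builders.
template <ck::ConvolutionDirection>
struct direction_name;

template <>
struct direction_name<ck::ConvolutionDirection::Forward>
{
    static const char* get() { return "forward"; }
};

template <>
struct direction_name<ck::ConvolutionDirection::BackwardWeight>
{
    static const char* get() { return "backward-weight"; }
};

int main()
{
    // prints "backward-weight"
    std::printf("%s\n", direction_name<ck::ConvolutionDirection::BackwardWeight>::get());
    return 0;
}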
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -2,22 +2,21 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
-#include "implicitgemm_params.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
+#include "convolution_common.hpp"

 namespace ck {

-template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC>
-struct make_vectorized_WeiDesc
-{
-};
+template <ConvolutionDirection, typename WeiDesc, index_t NonVectorizedC>
+struct make_vectorized_WeiDesc;

 template <typename WeiDesc, index_t NonVectorizedC>
-struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC>
+struct make_vectorized_WeiDesc<ConvolutionDirection::Forward, WeiDesc, NonVectorizedC>
 {
     __device__ constexpr auto get(WeiDesc&)
     {

@@ -30,8 +29,9 @@ struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonV
             .ReorderGivenNew2Old(Sequence<2, 0, 1>{});
     }
 };

 template <typename WeiDesc, index_t NonVectorizedC>
-struct make_vectorized_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc, NonVectorizedC>
+struct make_vectorized_WeiDesc<ConvolutionDirection::BackwardWeight, WeiDesc, NonVectorizedC>
 {
     __device__ constexpr auto get(WeiDesc& desc)
     {

@@ -56,6 +56,7 @@ template <index_t GridSize,
           class OutGlobalDesc,
           class ConvStrides,
           class ConvDilations,
+          ConvolutionDirection ConvDirection,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,

@@ -83,8 +84,7 @@ template <index_t GridSize,
           class WeiBlockCopySrcAccessOrder,
           class WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
-          index_t WeiBlockCopyDstDataPerWrite_K,
-          ImplicitGemmDirection conv_dir>
+          index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,

@@ -198,12 +198,11 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
         // input blockwise copy
         // slice a merged tensor, reorder and copy to a normal tensor
         // this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
             decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
             decltype(in_e_n1_b_n2_2eor4e_block_desc),
             decltype(in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
             InBlockCopySubLengths_E_N1_B_N2_EPACK,
             InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
             InBlockCopyThreadClusterArrangeOrder,

@@ -212,13 +211,15 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
             2,
             4,
             InBlockCopySrcDataPerRead_B,
-            InBlockCopyDstDataPerWrite_EPACK>(
-                {0, 0, b_block_data_on_global, 0, 0}, {0, 0, 0, 0, 0});
+            InBlockCopyDstDataPerWrite_EPACK>({0, 0, b_block_data_on_global, 0, 0},
+                                              {0, 0, 0, 0, 0});

         // weight tensor
         // tensor descriptor in device memory, src of blockwise copy
         constexpr auto wei_e_k_2eor4e_global_desc =
-            make_vectorized_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{}
+            make_vectorized_WeiDesc<ConvDirection, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{}
                 .get(wei_k_c_y_x_global_desc);

         // tensor descriptor in LDS, dst of blockwise copy

@@ -235,8 +236,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor
         // this copy operator already have blockwise offset built-in
-        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
             decltype(wei_e_k_2eor4e_global_desc),
             decltype(wei_e_k_2eor4e_block_desc),
             decltype(wei_e_k_2eor4e_block_desc.GetLengths()),

@@ -248,8 +249,7 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
             0,
             2,
             WeiBlockCopySrcDataPerRead_E,
-            WeiBlockCopyDstDataPerWrite_EPACK>(
-                {0, k_block_data_on_global, 0}, {0, 0, 0});
+            WeiBlockCopyDstDataPerWrite_EPACK>({0, k_block_data_on_global, 0}, {0, 0, 0});

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx

@@ -279,7 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
         const auto blockwise_gemm =
             BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
-                                                                    EPACK,
                                                                     decltype(a_e_k_block_mtx_desc),
                                                                     decltype(b_e_n1bn2_block_mtx_desc),
                                                                     decltype(c_k0k2_n1n2_thread_mtx_desc),

@@ -347,12 +346,12 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
             Float* p_wei_block_next =
                 even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
-            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
-                fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0, 0>{}, True);
+            static_if<ConvDirection == ConvolutionDirection::BackwardWeight>{}([&](auto fwd) {
+                fwd(blockwise_wei_copy).MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
             }).Else([&](auto fwd) {
                 p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);

@@ -361,9 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

             // LDS double buffer: GEMM on current data
             const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =

@@ -375,20 +373,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
             blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

            // even iteration
-            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
-            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
-                fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0, 0>{}, True);
+            static_if<ConvDirection == ConvolutionDirection::BackwardWeight>{}([&](auto fwd) {
+                fwd(blockwise_wei_copy).MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
            }).Else([&](auto fwd) {
                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
            });

@@ -396,8 +394,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

            // LDS double buffer: GEMM on current data
            // Vectorize the pointer to match with how half/bfloat16 datatypes are

@@ -415,9 +413,9 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
            blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                                     p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration

@@ -479,7 +477,7 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
            out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                k_thread_data_on_global, 0, b_thread_data_on_global, 0);

-        ThreadwiseGenericTensorSliceCopy_v1r2<
+        ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
            decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw_lds_double_buffer.hpp

@@ -2,13 +2,13 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
-#include "implicitgemm_params.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
+#include "convolution_common.hpp"

 namespace ck {

@@ -21,7 +21,7 @@ template <index_t GridSize,
           class WeiGlobalDesc,
           class OutGlobalDesc, // exchanged outside for backward
           class ConvStrides,
-          ImplicitGemmDirection Direction,
+          ConvolutionDirection ConvDirection,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,

@@ -56,7 +56,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
-        constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData;
+        constexpr bool isForward = (ConvDirection == ConvolutionDirection::Forward);

        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters

@@ -161,8 +161,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
            decltype(in_e_n1_b_n2_global_merged_desc),
            decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),

@@ -174,8 +174,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
            2,
            3,
            InBlockCopySrcDataPerRead_B,
-            InBlockCopyDstDataPerWrite_N2>(
-                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
+            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy

@@ -198,7 +197,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
                decltype(wei_e_k_global_desc),
                decltype(wei_e_k_block_desc),
                decltype(wei_e_k_block_desc.GetLengths()),

@@ -239,7 +238,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
-                                                                    1, // EPACK = 1
                                                                    decltype(a_e_k_block_mtx_desc),
                                                                    decltype(b_e_n1bn2_block_mtx_desc),
                                                                    decltype(c_k0k2_n1n2_thread_mtx_desc),

@@ -301,50 +299,49 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
            Float* p_wei_block_next =
                even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

            // even iteration
-            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                                     p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                      p_wei_block_double + wei_block_space);

            // odd iteration

@@ -437,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
            out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                k_thread_data_on_global, 0, b_thread_data_on_global, 0);

-        ThreadwiseGenericTensorSliceCopy_v1r2<
+        ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
            decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp  (new file, mode 100644)

#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "convolution_common.hpp"

namespace ck {

template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1;

template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
{
    template <typename WeiDesc>
    __device__ constexpr auto operator()(WeiDesc) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I3 = Number<3>{};

        return reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
    }
};

template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
{
    template <typename WeiDesc>
    __device__ constexpr auto operator()(WeiDesc) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};

        constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
        constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        return transform_tensor_descriptor(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
            make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
            make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
};

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccDataType,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          ConvolutionDirection ConvDirection,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t GemmNRepeat,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename InBlockCopySubLengths_E_N1_B_N2,
          typename InBlockCopyClusterLengths_E_N1_B_N2,
          typename InBlockCopyThreadClusterArrangeOrder,
          typename InBlockCopySrcAccessOrder,
          typename InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          typename WeiBlockCopySubLengths_E_K,
          typename WeiBlockCopyClusterLengths_E_K,
          typename WeiBlockCopyThreadClusterArrangeOrder,
          typename WeiBlockCopySrcAccessOrder,
          typename WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto generic_address_space =
            integral_constant<AddressSpace, AddressSpace::generic>{};
        constexpr auto global_address_space =
            integral_constant<AddressSpace, AddressSpace::global>{};

        static_assert(ConvDirection == ConvolutionDirection::Forward ||
                          ConvDirection == ConvolutionDirection::BackwardWeight,
                      "wrong! this kernel only support convolution forward and backward-weight");

        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThreadSubC;

        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;

        // input tensor
        // global tensor in global memory
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

        // global tensor in global memory, src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
            in_n0_n1_n2_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{},
                       PassThrough<N1>{},
                       Merge<Sequence<N0, Ho, Wo>>{},
                       PassThrough<N2>{}),
            make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // block tensor in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input tensor blockwise copy
        auto blockwise_in_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(in_e_n1_b_n2_global_desc),
                                               decltype(in_e_n1_b_n2_block_desc),
                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_N1_B_N2,
                                               InBlockCopyClusterLengths_E_N1_B_N2,
                                               InBlockCopyThreadClusterArrangeOrder,
                                               InBlockCopySrcAccessOrder,
                                               InBlockCopyDstAccessOrder,
                                               2,
                                               3,
                                               InBlockCopySrcDataPerRead_B,
                                               InBlockCopyDstDataPerWrite_N2>(
                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // global tensor in global memory, src of blockwise copy
        // It is constructed differently, depending on whether forward or backward weight
        // convolution
        constexpr auto wei_e_k_global_desc =
            make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);

        // block tensor in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // weight tensor blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
                                               WeiBlockCopyThreadClusterArrangeOrder,
                                               WeiBlockCopySrcAccessOrder,
                                               WeiBlockCopyDstAccessOrder,
                                               0,
                                               1,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K>(
                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[EPerBlock, KPerBlock] is in LDS
        // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.GetLength(I0),
                                          in_e_n1_b_n2_block_desc.GetLength(I1) *
                                              in_e_n1_b_n2_block_desc.GetLength(I2) *
                                              in_e_n1_b_n2_block_desc.GetLength(I3),
                                          in_e_n1_b_n2_block_desc.GetStride(I0));

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
                                                                    decltype(a_e_k_block_mtx_desc),
                                                                    decltype(b_e_n1bn2_block_mtx_desc),
                                                                    decltype(c_k0k1_n1n2_thread_mtx_desc),
                                                                    GemmMPerThreadSubC,
                                                                    GemmNPerThreadSubC,
                                                                    GemmMLevel0Cluster,
                                                                    GemmNLevel0Cluster,
                                                                    GemmMLevel1Cluster,
                                                                    GemmNLevel1Cluster,
                                                                    GemmKPerThreadLoop,
                                                                    GemmDataPerReadA,
                                                                    GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(
                p_in_global, p_in_block_double, global_address_space, generic_address_space);
            blockwise_wei_copy.Run(
                p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(
                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
                blockwise_wei_copy.RunLoadThreadBuffer(
                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(
                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
                blockwise_wei_copy.RunLoadThreadBuffer(
                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

                // LDS double buffer: store last data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                                       p_in_block_double + in_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_in_block_double + in_block_space,
                                   p_out_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
            }
        }

        // copy output: register to global memory
        {
            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / K1;

            // define output tensor descriptor for threadwise copy
            // thread output tensor, src of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});

            // global output tensor
            constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
                out_n_k_ho_wo_global_desc,
                make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
                           UnMerge<Sequence<K0, K1>>{},
                           PassThrough<Ho>{},
                           PassThrough<Wo>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));

            // global output tensor, dst of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
                out_n0_n1_n2_k0_k1_ho_wo_global_desc,
                make_tuple(PassThrough<K0>{},
                           PassThrough<K1>{},
                           PassThrough<N1>{},
                           Merge<Sequence<N0, Ho, Wo>>{},
                           PassThrough<N2>{}),
                make_tuple(
                    Sequence<3>{}, Sequence<4>{}, Sequence<1>{}, Sequence<0, 5, 6>{}, Sequence<2>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(out_k0_k1_n1_b_n2_thread_desc),
                decltype(out_k0_k1_n1_b_n2_global_desc),
                decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 5, 1>::type,
                3,
                1,
                1>({0, 0, 0, 0, 0},
                   {k_thread_data_on_global / K1,
                    k_thread_data_on_global % K1,
                    0,
                    b_thread_data_on_global,
                    0})
                .Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
        }
    }
};

} // namespace ck
#endif
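The Run() body above follows the LDS double-buffering scheme used throughout these kernels: preload the first slice, then in each unrolled pair of iterations stage the next slice into a thread buffer while the GEMM consumes the current LDS buffer, and finish with a tail that handles either one or two remaining slices (the E % (2 * EPerBlock) check). A simplified host-side sketch of that ping-pong control flow, with placeholder load/GEMM functions that are not from the repository; it assumes E is a multiple of EPerBlock:

#include <cstdio>
#include <vector>

// Placeholder for the blockwise copy: stage one slice of length EPerBlock.
static void load_tile(const std::vector<float>& src, int begin, int len, std::vector<float>& tile)
{
    for(int i = 0; i < len; ++i)
        tile[i] = src[begin + i];
}

// Placeholder for blockwise_gemm.Run(): accumulate the staged slice.
static void gemm_accumulate(const std::vector<float>& tile, float& acc)
{
    for(float v : tile)
        acc += v;
}

int main()
{
    const int E = 8, EPerBlock = 2; // assumes E % EPerBlock == 0
    std::vector<float> src(E, 1.0f);

    // two LDS buffers (ping-pong), modelled here as two host vectors
    std::vector<float> lds[2] = {std::vector<float>(EPerBlock), std::vector<float>(EPerBlock)};
    float acc = 0.0f;

    // preload: first slice into buffer 0
    load_tile(src, 0, EPerBlock, lds[0]);

    // main body: stage the next slice while consuming the current one
    int e = 0;
    for(; e + 2 * EPerBlock < E; e += 2 * EPerBlock)
    {
        load_tile(src, e + EPerBlock, EPerBlock, lds[1]);     // load next
        gemm_accumulate(lds[0], acc);                         // GEMM on current
        load_tile(src, e + 2 * EPerBlock, EPerBlock, lds[0]); // load next
        gemm_accumulate(lds[1], acc);                         // GEMM on current
    }

    // tail: two slices left if E is a multiple of 2 * EPerBlock, otherwise one
    if(E % (2 * EPerBlock) == 0)
    {
        load_tile(src, e + EPerBlock, EPerBlock, lds[1]);
        gemm_accumulate(lds[0], acc); // GEMM on 2nd-last slice
        gemm_accumulate(lds[1], acc); // GEMM on last slice
    }
    else
    {
        gemm_accumulate(lds[0], acc); // GEMM on last slice
    }

    std::printf("acc = %.1f (expected %d)\n", acc, E);
    return 0;
}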
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp  (new file, mode 100644)

#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "convolution_common.hpp"

namespace ck {

template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1_deprecated;

template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::Forward>
{
    template <typename WeiDesc>
    __device__ constexpr auto operator()(WeiDesc) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I3 = Number<3>{};

        return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
    }
};

template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::BackwardWeight>
{
    template <typename WeiDesc>
    __device__ constexpr auto operator()(WeiDesc) const
    {
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return make_ConstantMergedTensorDescriptor(
            WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
    }
};

// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          ConvolutionDirection ConvDirection,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t GemmNRepeat,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2,
          class InBlockCopyClusterLengths_E_N1_B_N2,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder,
          class InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder,
          class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto generic_address_space =
            integral_constant<AddressSpace, AddressSpace::generic>{};
        constexpr auto global_address_space =
            integral_constant<AddressSpace, AddressSpace::global>{};

        static_assert(ConvDirection == ConvolutionDirection::Forward ||
                          ConvDirection == ConvolutionDirection::BackwardWeight,
                      "wrong! this kernel only support convolution forward and backward-weight");

        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThreadSubC;

        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
        constexpr auto in_n0_n1_n2_h_w_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        // batch descritpor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                .Extract(Sequence<1, 2, 3>{});

        // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<4>{},
            Sequence<3, 6, 7>{},
            Sequence<5>{});

        // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
            BlockSize,
            decltype(in_e_n1_b_n2_global_merged_desc),
            decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2,
            InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            2,
            3,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // Iensor descriptor in device memory, src of blockwise copy
        // It is constructed differently, depending on whether forward or backward weight
        // convolution
        constexpr auto wei_e_k_global_desc =
            make_wei_e_k_global_desc_v4r1_deprecated<ConvDirection>{}(wei_k_c_y_x_global_desc);

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
            BlockSize,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            0,
            1,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[EPerBlock, KPerBlock] is in LDS
        // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
                                                                    decltype(a_e_k_block_mtx_desc),
                                                                    decltype(b_e_n1bn2_block_mtx_desc),
                                                                    decltype(c_k0k1_n1n2_thread_mtx_desc),
                                                                    GemmMPerThreadSubC,
                                                                    GemmNPerThreadSubC,
                                                                    GemmMLevel0Cluster,
                                                                    GemmNLevel0Cluster,
                                                                    GemmMLevel1Cluster,
                                                                    GemmNLevel1Cluster,
                                                                    GemmKPerThreadLoop,
                                                                    GemmDataPerReadA,
                                                                    GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(
                p_in_global, p_in_block_double, global_address_space, generic_address_space);
            blockwise_wei_copy.Run(
                p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(
                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
                blockwise_wei_copy.RunLoadThreadBuffer(
                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            // even iteration
            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
            blockwise_in_copy.RunLoadThreadBuffer(
                p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
            blockwise_wei_copy.RunLoadThreadBuffer(
                p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

            // LDS double buffer: store next data to LDS
            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                                   p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                    p_wei_block_double + wei_block_space);

            // odd iteration
            __syncthreads();

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                               p_in_block_double + in_block_space,
                               p_out_thread);
        }

        // copy output: register to global memory
        {
            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register, src of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});

            // output memory layout descriptor in device memory
            constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});

            // output merged global tensor descriptor, dst of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
                make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
                                                    Sequence<3>{},
                                                    Sequence<4>{},
                                                    Sequence<1>{},
                                                    Sequence<0, 5, 6>{},
                                                    Sequence<2>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
                decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
                decltype(out_k0_k1_n1_b_n2_global_merged_desc),
                decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
                arithmetic_sequence_gen<0, 5, 1>::type,
                arithmetic_sequence_gen<0, 5, 1>::type,
                3,
                3,
                1,
                1>({0, 0, 0, 0, 0},
                   {k_thread_data_on_global / K1,
                    k_thread_data_on_global % K1,
                    0,
                    b_thread_data_on_global,
                    0})
                .Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
        }
    }
};

} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -2,12 +2,12 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "implicitgemm_params.hpp"

 namespace ck {
@@ -173,12 +173,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
         // input blockwise copy
        //     slice a merged tensor, reorder and copy to a normal tensor
        //     this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
+        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
            BlockSize,
            decltype(in_e_b_global_desc),
            decltype(in_e_b_block_desc),
-           MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
-           NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
            decltype(in_e_b_block_desc.GetLengths()),
            InBlockCopySubLengths_E_B,
            InBlockCopyClusterLengths_E_B,
@@ -209,12 +207,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
        // operator for blockwise copy of weight into LDS
        //     slice a tensor, and copy it into another tensor
        //     this copy operator already have blockwise offset built-in
-        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
+        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
            BlockSize,
            decltype(wei_e_k_global_desc),
            decltype(wei_e_k_block_desc),
-           NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
-           NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
@@ -300,22 +296,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
            Float* p_wei_block_next =
                even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);

            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
-                blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
-            }).Else([&](auto fwd) {
-                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
+                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
+            }).Else([&](auto) {
+                p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
            });

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

            // LDS double buffer: GEMM on current data
            const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
@@ -327,29 +322,29 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
            blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
        }
    }

    // LDS double buffer: tail
    {
-        Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-        Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+        Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+        Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

        // even iteration
-        blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
+        blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);

        static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
-            blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
-        }).Else([&](auto fwd) {
-            p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
+            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
+        }).Else([&](auto) {
+            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
        });

        __syncthreads();

        // LDS doubel buffer: load next data from device mem
-        blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-        blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+        blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+        blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

        // LDS double buffer: GEMM on current data
        // Vectorize the pointer to match with how half/bfloat16 datatypes are
@@ -368,9 +363,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
        blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);

        // LDS double buffer: store next data to LDS
-        blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
+        blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                               p_in_block_double + in_block_space);
-        blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
+        blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                p_wei_block_double + wei_block_space);

        // odd iteration
@@ -426,11 +421,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
        const index_t b_thread_data_on_global =
            b_block_data_on_global + c_thread_mtx_on_block.col;

-        auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
+        auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
            decltype(out_k0_k1_k2_b_thread_desc),
            decltype(out_k0_k1_k2_b_global_desc),
-           NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
-           MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
            OutThreadCopySliceLengths,
            arithmetic_sequence_gen<0, 4, 1>::type,
            arithmetic_sequence_gen<0, 4, 1>::type,
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer.hpp
@@ -2,12 +2,12 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "implicitgemm_params.hpp"

 namespace ck {
@@ -131,11 +131,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
        //     slice a merged tensor, reorder and copy to a normal tensor
        //     this copy operator already has blockwise offset built-in
        auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
                                               decltype(in_e_b_global_desc),
                                               decltype(in_e_b_block_desc),
-                                              MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
-                                              NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
                                               decltype(in_e_b_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_B,
                                               InBlockCopyClusterLengths_E_B,
@@ -167,12 +165,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
        // operator for blockwise copy of weight into LDS
        //     slice a tensor, and copy it into another tensor
        //     this copy operator already have blockwise offset built-in
-        auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+        auto blockwise_wei_copy =
+            BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
-                                              NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
-                                              NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
@@ -182,7 +178,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
                                               0,
                                               1,
                                               WeiBlockCopySrcDataPerRead_E,
-                                              WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
+                                              WeiBlockCopyDstDataPerWrite_K>(
+                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
@@ -253,50 +250,49 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
            Float* p_wei_block_next =
                even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
            p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
        }
    }

    // LDS double buffer: tail
    {
-        Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-        Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+        Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+        Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

        // even iteration
-        blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
+        blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
        p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];

        __syncthreads();

        // LDS doubel buffer: load next data from device mem
-        blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-        blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+        blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+        blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

        // LDS double buffer: GEMM on current data
        blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

        // LDS double buffer: store next data to LDS
-        blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
+        blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                               p_in_block_double + in_block_space);
-        blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
+        blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                p_wei_block_double + wei_block_space);

        // odd iteration
@@ -373,11 +369,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
        const index_t b_thread_data_on_global =
            b_block_data_on_global + c_thread_mtx_on_block.col;

-        auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
+        auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
            decltype(out_k0_k1_k2_b_thread_desc),
            decltype(out_k0_k1_k2_b_global_desc),
-           NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
-           MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
            OutThreadCopySliceLengths,
            arithmetic_sequence_gen<0, 4, 1>::type,
            arithmetic_sequence_gen<0, 4, 1>::type,
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -2,12 +2,12 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "implicitgemm_params.hpp"

 namespace ck {
@@ -80,8 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
-        if(blockIdx.x * blockDim.x + threadIdx.x == 0)
-            printf("conv dir %d", conv_dir);
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
@@ -162,11 +160,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
        //     slice a merged tensor, reorder and copy to a normal tensor
        //     this copy operator already has blockwise offset built-in
        auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
                                               decltype(in_e_b_global_desc),
                                               decltype(in_e_b_block_desc),
-                                              MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
-                                              NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
                                               decltype(in_e_b_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_B,
                                               InBlockCopyClusterLengths_E_B,
@@ -194,12 +190,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
        // operator for blockwise copy of weight into LDS
        //     slice a tensor, and copy it into another tensor
        //     this copy operator already have blockwise offset built-in
-        auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+        auto blockwise_wei_copy =
+            BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
-                                              NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
-                                              NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
@@ -209,7 +203,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
                                               0,
                                               1,
                                               WeiBlockCopySrcDataPerRead_E,
-                                              WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
+                                              WeiBlockCopyDstDataPerWrite_K>(
+                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
@@ -261,6 +256,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

+#if 1
        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
@@ -280,58 +276,50 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
            Float* p_wei_block_next =
                even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

-            Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-            Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
-            static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
-                blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
-            }).Else([&](auto fwd) {
-                p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
-            });
+            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

            __syncthreads();

            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
+            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
        }
    }
+#endif

    // LDS double buffer: tail
    {
-        Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-        Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+        Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+        Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

        // even iteration
-        blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
-        static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
-            blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
-        }).Else([&](auto fwd) {
-            p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
-        });
+        blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+        blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

        __syncthreads();

        // LDS doubel buffer: load next data from device mem
-        blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-        blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
+        blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+        blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);

        // LDS double buffer: GEMM on current data
        blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

        // LDS double buffer: store next data to LDS
-        blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
+        blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                               p_in_block_double + in_block_space);
-        blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
+        blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                p_wei_block_double + wei_block_space);

        // odd iteration
@@ -384,11 +372,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
        const index_t b_thread_data_on_global =
            b_block_data_on_global + c_thread_mtx_on_block.col;

-        auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
+        auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
            decltype(out_k0_k1_k2_b_thread_desc),
            decltype(out_k0_k1_k2_b_global_desc),
-           NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
-           MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
            OutThreadCopySliceLengths,
            arithmetic_sequence_gen<0, 4, 1>::type,
            arithmetic_sequence_gen<0, 4, 1>::type,
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp
@@ -2,7 +2,8 @@
 #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "tensor_descriptor.hpp"

 namespace ck {
@@ -31,6 +32,11 @@ struct ConstantMatrixDescriptor
        return irow * RowStride_ + icol;
    }

+    __host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol)
+    {
+        return GetOffsetFromMultiIndex(irow, icol);
+    }
+
    template <index_t SubNRow, index_t SubNCol>
    __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, Number<SubNCol>)
@@ -52,10 +58,22 @@ __host__ __device__ constexpr auto
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}

-template <class... Ts>
-__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
+template <typename... Ts>
+__host__ __device__ constexpr auto
+make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
+    using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
+    static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
+    static_assert(TDesc::GetStrides()[1] == 1, "wrong");
+    return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
+                                    TDesc::GetLengths()[1],
+                                    TDesc::GetStrides()[0]>{};
+}
+
+template <typename... Ts>
+__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
+{
-    using TDesc = ConstantTensorDescriptor<Ts...>;
+    using TDesc = NativeTensorDescriptor<Ts...>;
    static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
    static_assert(TDesc::GetStrides()[1] == 1, "wrong");
    return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
@@ -63,7 +81,7 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD
                                    TDesc::GetStrides()[0]>{};
}

-template <class TDesc>
+template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    printf(
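Both overloads keep the same contract as before: a 2-D descriptor whose inner stride is 1 is lifted into a matrix descriptor of NRow, NCol and RowStride. Below is a rough compile-time sketch of that contract, written against plain integer parameters rather than the CK descriptor types (the struct name and parameters are invented for the illustration):

// Sketch only: what make_ConstantMatrixDescriptor checks and extracts,
// restated with plain template parameters.
template <int Length0, int Length1, int Stride0, int Stride1>
struct MatrixViewSketch
{
    static_assert(Stride1 == 1, "inner dimension must be contiguous");
    static constexpr int NRow      = Length0;
    static constexpr int NCol      = Length1;
    static constexpr int RowStride = Stride0;
};

// e.g. an 8x5 tile stored with a row pitch of 8 elements
static_assert(MatrixViewSketch<8, 5, 8, 1>::RowStride == 8, "row pitch carried through unchanged");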
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor_deprecated.hpp
0 → 100644
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"

namespace ck {

// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor_deprecated
{
    using Type = ConstantMergedTensorDescriptor_deprecated;

    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();

    __host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated()
    {
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions

        // TODO: check there is no duplication in OriginalDimMergeSeqs
    }

    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return std::get<IDim>(mOriginalDimMergeSeqs);
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
    {
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);

        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");

        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();

        return OriginalTensorDesc::GetStride(Number<idim_original>{});
    }

    // this is a hack to return the stride of the last original dimension of a merged dimension
    // TODO: refactor this once the concept of "dimension" is used
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
    {
        constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();

        return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
    }

    __host__ __device__ static constexpr auto GetLengths()
    {
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
    }

    __host__ __device__ static constexpr auto GetElementSize()
    {
        return OriginalTensorDesc::GetElementSize();
    }

    template <class OriginalDimsPartial>
    struct lambda_1_GetOriginalMultiIndexFromMultiIndex
    {
        const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
        Array<index_t, nOriginalDim>& original_multi_id;

        __host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
            Array<index_t, nOriginalDim>& original_multi_id_)
            : original_multi_id_partial(original_multi_id_partial_),
              original_multi_id(original_multi_id_)
        {
        }

        template <index_t I>
        __host__ __device__ constexpr void operator()(Number<I>) const
        {
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

            index_t itmp = original_multi_id_partial[I];

            original_multi_id(idim_original) = itmp;
        }
    };

    struct lambda_0_GetOriginalMultiIndexFromMultiIndex
    {
        const Array<index_t, nDim>& multi_id;
        Array<index_t, nOriginalDim>& original_multi_id;

        __host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
            : multi_id(multi_id_), original_multi_id(original_multi_id_)
        {
        }

        template <index_t IDim>
        __host__ __device__ constexpr void operator()(Number<IDim>) const
        {
            constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[IDim]);

            static_for<0, original_dims_partial.GetSize(), 1>{}(
                lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
                    original_multi_id_partial, original_multi_id));
        }
    };

    // return type is Array<...>
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
            lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));

        return original_multi_id;
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    template <class... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
    {
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
    }

    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());

        return packed_desc.GetMultiIndexFrom1dIndex(id);
    }

    __host__ __device__ static constexpr auto Pack()
    {
        constexpr auto lengths = GetLengths();
        constexpr auto strides = calculate_tensor_strides_packed(lengths);
        return ConstantTensorDescriptor_deprecated<decltype(lengths), decltype(strides)>{};
    }
};

template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
{
    return ConstantMergedTensorDescriptor_deprecated<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}

template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
{
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
}

} // namespace ck
#endif
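The merged descriptor exists so that several original dimensions can be addressed through a single index: the merged index is first decomposed back into the original multi-index, which is then folded with the original strides into an offset. A small worked sketch with plain integers (the NCHW sizes below are made up for the example and are not tied to any kernel configuration):

// Sketch only: dims 0, 2, 3 of a packed NCHW tensor merged into one dimension B = N*H*W.
constexpr int N = 2, C = 3, H = 4, W = 5;
constexpr int stride_n = C * H * W, stride_c = H * W, stride_h = W, stride_w = 1;

constexpr int offset_from_merged_index(int b, int c)
{
    // decompose b over the merged dims (N, H, W), highest dimension first
    const int n = b / (H * W);
    const int h = (b % (H * W)) / W;
    const int w = b % W;
    return n * stride_n + c * stride_c + h * stride_h + w * stride_w;
}

// b = 27 -> (n, h, w) = (1, 1, 2); with c = 2 the offset is 1*60 + 2*20 + 1*5 + 2 = 107
static_assert(offset_from_merged_index(27, 2) == 107, "merged (b, c) maps back to the NCHW offset");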
composable_kernel/include/tensor_description/ConstantTensorDescriptor_deprecated.hpp
0 → 100644
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP

#include "common_header.hpp"

namespace ck {

template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed_deprecated(Lengths)
{
    return reverse_inclusive_scan_sequence(
               Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
        .PushBack(Number<1>{});
}

template <class Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number<Align>)
{
    constexpr index_t L_back_align =
        Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

    return calculate_tensor_strides_packed_deprecated(
        Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
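// Illustration only (not part of this header): for lengths {2, 3, 4} the packed rule
// above yields strides {3*4, 4, 1} = {12, 4, 1}, i.e. a reverse inclusive product scan
// of the tail lengths with a 1 appended for the innermost dimension, so element
// (i0, i1, i2) sits at offset i0*12 + i1*4 + i2.
static_assert(3 * 4 == 12, "packed outermost stride for lengths {2, 3, 4}");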
template <class Lengths, class Strides>
struct ConstantTensorDescriptor_deprecated
{
    using Type = ConstantTensorDescriptor_deprecated;

    static constexpr index_t nDim = Lengths::GetSize();

    __host__ __device__ constexpr ConstantTensorDescriptor_deprecated()
    {
        static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
    }

    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return Sequence<IDim>{};
    }

    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

    __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

    __host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }

    __host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }

    struct lambda_AreDimensionsContinuous
    {
        bool& is_continuous;

        __host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
            : is_continuous(is_continuous_)
        {
        }

        template <index_t IDim_>
        __host__ __device__ constexpr void operator()(Number<IDim_>) const
        {
            constexpr auto IDim    = Number<IDim_>{};
            constexpr auto IDim_p1 = Number<IDim_ + 1>{};

            is_continuous =
                is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
                                  GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
        }
    };

    __host__ __device__ static constexpr bool AreDimensionsContinuous()
    {
        bool is_continuous = true;

        static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));

        return is_continuous;
    }

    __host__ __device__ static constexpr bool IsPackedTensor()
    {
        return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
    }

    template <class T>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
    {
        return false;
    }

    __host__ __device__ static constexpr auto GetElementSize()
    {
        return Number<reduce_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
    }

    __host__ __device__ static constexpr auto GetElementSpace()
    {
        constexpr index_t element_space_unaligned = reduce_on_sequence(
            (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

        return Number<element_space_unaligned>{};
    }

    // emulate constexpr lambda
    template <index_t NSize>
    struct lambda_GetOffsetFromMultiIndex
    {
        Array<index_t, NSize>& multi_id;
        index_t& offset;

        __host__ __device__ constexpr lambda_GetOffsetFromMultiIndex(
            Array<index_t, NSize>& multi_id_, index_t& offset_)
            : multi_id(multi_id_), offset(offset_)
        {
        }

        template <class X>
        __host__ __device__ constexpr void operator()(X IDim) const
        {
            offset += multi_id[IDim] * Type::GetStride(IDim);
        }
    };

    template <index_t NSize>
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
    {
        static_assert(NSize == nDim, "wrong! Dimension not consistent");

        index_t offset = 0;

        static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));

        return offset;
    }

    template <class... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
    {
        return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
    }

    template <index_t... Is>
    __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

        constexpr auto multi_id = Sequence<Is...>{};

        return Number<reduce_on_sequence(
            multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
    }

    // emulate constexpr lambda
    template <class PackedStrides>
    struct lambda_GetMultiIndexFrom1dIndex
    {
        index_t& id;
        Array<index_t, nDim>& multi_id;

        __host__ __device__ constexpr lambda_GetMultiIndexFrom1dIndex(
            index_t& id_, Array<index_t, nDim>& multi_id_)
            : id(id_), multi_id(multi_id_)
        {
        }

        template <class IDim_>
        __host__ __device__ constexpr void operator()(IDim_) const
        {
            constexpr auto IDim      = IDim_{};
            constexpr index_t stride = PackedStrides::Get(IDim);
            multi_id(IDim)           = id / stride;
            id -= multi_id[IDim] * stride;
        }
    };

    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        Array<index_t, nDim> multi_id;

        using PackedStrides = decltype(calculate_tensor_strides_packed_deprecated(GetLengths()));

        // calculate index in each of the dimensions in the order of their dimension
        static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));

        multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});

        return multi_id;
    }
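    // Illustration only (not part of this header): the two mappings above are inverses of
    // each other for a packed tensor. With lengths {2, 3, 4} (packed strides {12, 4, 1}),
    // multi-index (1, 2, 3) maps to offset 1*12 + 2*4 + 3 = 23, and 23 decomposes back to
    // 23/12 = 1, (23%12)/4 = 2, 23%4 = 3.
    static_assert(1 * 12 + 2 * 4 + 3 == 23, "offset / multi-index round trip, worked example");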
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        return multi_id;
    }

    // This function doesn't do carry check on the highest dimension for positive stepping (or
    // borrow check on the highest dimension for negative stepping), for performance reason. It is
    // the user's responsibility to make sure the result "new_multi_id" is not out-of-bound on the
    // highest dimension for positive stepping (or on the lowest dimension for negative stepping)
    template <bool PositiveDirection>
    __host__ __device__ static Array<index_t, nDim>
    UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
                                           index_t step_size_of_1d_index,
                                           integral_constant<bool, PositiveDirection>)
    {
        Array<index_t, nDim> new_multi_id;

        const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);

        static_if<PositiveDirection>{}([&](auto) {
            new_multi_id = old_multi_id + step_sizes;

            bool carry = false;

            // do carry check in reversed order, starting from lowest dimension
            // don't check the highest dimension
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
                constexpr index_t idim = nDim - 1 - IDimReverse;
                constexpr auto IDim    = Number<idim>{};

                if(carry)
                {
                    ++new_multi_id(idim);
                }

                carry = false;

                if(new_multi_id[idim] >= GetLength(IDim))
                {
                    new_multi_id(idim) -= GetLength(IDim);
                    carry = true;
                }
            });
        }).Else([&](auto) {
            // shift up multi-id to avoid unsigned integer underflow during intermediate
            // calculations. After the shift, should have new_multi_id[...] >= 1
            new_multi_id = old_multi_id + (GetLengths() - step_sizes);

            bool borrow = false;

            // do borrow check in reversed order, starting from lowest dimension
            // don't check the highest dimension
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
                constexpr index_t idim = nDim - 1 - IDimReverse;
                constexpr auto IDim    = Number<idim>{};

                if(borrow)
                {
                    --new_multi_id(idim);
                }

                borrow = false;

                if(new_multi_id[idim] < GetLength(IDim))
                {
                    new_multi_id(idim) += GetLength(IDim);
                    borrow = true;
                }
            });

            // shift back down multi-id
            // here, should have new_multi_id[...] >= GetLengths()
            new_multi_id = new_multi_id - GetLengths();
        });

        return new_multi_id;
    }
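    // Illustration only (not part of this header): the carry logic above, worked out on a
    // small packed tensor with lengths {4, 3, 2}. Stepping the 1d index by 3 means adding
    // the multi-index {0, 1, 1}; carries then propagate from the lowest dimension upward:
    //
    //   old {1, 2, 1}  (1d index 1*6 + 2*2 + 1 = 11)
    // + step {0, 1, 1} (1d step 3)
    //   raw  {1, 3, 2} -> dim2 wraps to 0 with carry, dim1 wraps to 1 with carry, dim0 becomes 2
    //   new  {2, 1, 0} (1d index 2*6 + 1*2 + 0 = 14 = 11 + 3)
    static_assert(2 * 6 + 1 * 2 + 0 == 11 + 3, "carry-propagated result matches the 1d arithmetic");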
    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
    {
        static_assert(sizeof...(IDims) <= GetNumOfDimension(),
                      "wrong! too many number of dimensions to be extracted");

        using extract_lengths = decltype(Lengths::Extract(extract_dims...));
        using extract_strides = decltype(Strides::Extract(extract_dims...));

        return ConstantTensorDescriptor_deprecated<extract_lengths, extract_strides>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
    {
        return Extract(Number<IDims>{}...);
    }

    template <class... Ts>
    __host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor_deprecated<Ts...>)
    {
        using leaf_tensor = ConstantTensorDescriptor_deprecated<Ts...>;

        return ConstantTensorDescriptor_deprecated<
            decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
            decltype(GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
    }

    template <index_t IDimVector, index_t DataPerVector>
    struct lambda_IsVectorizationAllowed
    {
        bool& is_allowed;

        __host__ __device__ constexpr lambda_IsVectorizationAllowed(bool& is_allowed_)
            : is_allowed(is_allowed_)
        {
        }

        template <index_t IDim_>
        __host__ __device__ constexpr void operator()(Number<IDim_>) const
        {
            constexpr auto IDim = Number<IDim_>{};

            if(IDimVector != IDim && Strides::Get(IDim) % DataPerVector != 0)
            {
                is_allowed = false;
            }
        }
    };

    template <index_t IDimVector, index_t DataPerVector>
    __host__ __device__ static constexpr bool IsVectorizationAllowed(Number<IDimVector>,
                                                                     Number<DataPerVector>)
    {
        bool is_allowed = (Strides{}[IDimVector] == 1 || DataPerVector == 1) &&
                          Lengths{}[IDimVector] % DataPerVector == 0;

        static_for<0, nDim, 1>{}(
            lambda_IsVectorizationAllowed<IDimVector, DataPerVector>{is_allowed});

        return is_allowed;
    }

    template <index_t IDim, index_t DataPerVector>
    __host__ __device__ static constexpr auto Vectorize(Number<IDim>, Number<DataPerVector>)
    {
        constexpr auto idim            = Number<IDim>{};
        constexpr auto data_per_vector = Number<DataPerVector>{};

        static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");

        using vectorized_lengths =
            decltype(Lengths::Modify(Number<IDim>{}, Number<Lengths{}[IDim] / DataPerVector>{}));
        using vectorized_strides =
            decltype((Strides{} / Number<DataPerVector>{}).Modify(Number<IDim>{}, Number<1>{}));

        return ConstantTensorDescriptor_deprecated<vectorized_lengths, vectorized_strides>{};
    }
    template <index_t IDim, index_t SliceLen>
    __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
    {
        using slice_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLen>{}));

        return ConstantTensorDescriptor_deprecated<slice_lengths, Strides>{};
    }

    template <index_t... Is>
    __host__ __device__ static constexpr auto Slice(Sequence<Is...> slice_lengths)
    {
        static_assert(slice_lengths.GetSize() == nDim, "wrong!");

        return ConstantTensorDescriptor_deprecated<decltype(slice_lengths), Strides>{};
    }

    template <index_t IDim, index_t SliceLength, index_t SliceStride>
    __host__ __device__ static constexpr auto
    StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
    {
        constexpr index_t new_stride = Strides::Get(Number<IDim>{}) * SliceStride;

        using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
        using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));

        return ConstantTensorDescriptor_deprecated<new_lengths, new_strides>{};
    }

    template <index_t IDim, index_t... FoldIntervals>
    __host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
    {
        constexpr auto fold_intervals = Sequence<FoldIntervals...>{};

        constexpr index_t fold_intervals_product =
            reduce_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});

        constexpr auto unfold_length = GetLength(Number<IDim>{});
        constexpr auto unfold_stride = GetStride(Number<IDim>{});

        // length of the dimension to be folded needs to be dividable by fold_interval_product,
        // otherwise, folding is invalid
        static_assert(unfold_length % fold_intervals_product == 0,
                      "wrong! length on the dimension to be folded cannot be evenly divided!");

        // folded lengths
        constexpr auto fold_lengths =
            Sequence<unfold_length / fold_intervals_product>{}.PushBack(fold_intervals);

        // folded strides
        constexpr auto fold_strides =
            Number<unfold_stride>{} *
            reverse_inclusive_scan_sequence(
                fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});

        // left and right
        constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{};
        constexpr auto right =
            typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::type{};

        constexpr auto new_lengths =
            GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right));
        constexpr auto new_strides =
            GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right));

        return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
    }

    template <index_t IDim, index_t... FoldIntervals>
    __host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
    {
        return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
    }

    // this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
    template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
    __host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
    {
        static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
                          FirstUnfoldDim <= LastUnfoldDim,
                      "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

        // left and right
        constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
        constexpr auto middle =
            typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
        constexpr auto right =
            typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::type{};

        // dimensions to be unfolded need to be continuous
        static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");

        // unfolded length, stride
        constexpr index_t unfold_length = reduce_on_sequence(
            GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});

        constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});

        // new lengths, strides
        constexpr auto new_lengths = GetLengths()
                                         .Extract(left)
                                         .PushBack(Number<unfold_length>{})
                                         .PushBack(GetLengths().Extract(right));

        constexpr auto new_strides = GetStrides()
                                         .Extract(left)
                                         .PushBack(Number<unfold_stride>{})
                                         .PushBack(GetStrides().Extract(right));

        return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
    }
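    // Illustration only (not part of this header): Fold splits one dimension into several.
    // Folding a dimension of length 8, stride 1 with fold intervals {2, 2} gives
    //   lengths {8/(2*2), 2, 2} = {2, 2, 2}
    //   strides {1*(2*2), 1*2, 1} = {4, 2, 1}
    // i.e. the folded strides are the original stride times the reverse inclusive product
    // scan of the fold intervals with a trailing 1; Unfold is the inverse when the folded
    // dimensions are continuous.
    static_assert(8 % (2 * 2) == 0 && 1 * (2 * 2) == 4, "fold of a length-8, stride-1 dimension");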
    __host__ __device__ static constexpr auto Pack()
    {
        using packed_strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
        return ConstantTensorDescriptor_deprecated<Lengths, packed_strides>{};
    }

    template <class MapNew2Old>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
    {
        return ConstantTensorDescriptor_deprecated<
            decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
            decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
    }

    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        return ConstantTensorDescriptor_deprecated<
            decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
            decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
    }
};

template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
    using Strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
    return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}

template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}

template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
    using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
    return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_ConstantTensorDescriptor(const char* s,
                               ConstantTensorDescriptor_deprecated<Sequence<Lengths...>,
                                                                   Sequence<Strides...>>)
{
    constexpr index_t ndim = sizeof...(Lengths);

    static_assert(ndim > 0 && ndim <= 12, "wrong!");

    static_if<ndim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 3>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, ndim, Lengths..., Strides...);
    });
}

} // namespace ck
#endif
composable_kernel/include/tensor_description/dimension.hpp
0 → 100644
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP

#include "common_header.hpp"

namespace ck {

template <index_t Length, index_t Stride>
struct NativeDimension
{
    __host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }

    __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
};

} // namespace ck
#endif
composable_kernel/include/tensor_description/multi_index_transform.hpp
0 → 100644
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP

#include "common_header.hpp"

namespace ck {

template <index_t N>
using MultiIndex = Array<index_t, N>;

template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs... xs)
{
    return MultiIndex<sizeof...(Xs)>(xs...);
}

template <index_t Length>
struct PassThrough
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up;
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
    {
        return true;
    }
};
// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftPads, typename RightPads>
struct Pad
{
    static constexpr index_t nDim = LowerLengths::Size();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return LowerLengths{} + LeftPads{} + RightPads{};
    }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up - LeftPads{};
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
    {
#if 0
        struct lambda_no_pad
        {
            __host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
        };

        if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
           sequence_all_of(RightPads{}, lambda_no_pad{}))
        {
            return true;
        }
        else
#endif
        {
            bool flag = true;

            static_for<0, nDim, 1>{}([&](auto idim) {
                // only check if there is left-padding
                static_if<(LeftPads::At(idim) != 0)>{}(
                    [&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });

                // only check if there is right-padding
                static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
                    flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
                });
            });

            return flag;
        }
    }
};
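// Illustration only (not part of this header): with a lower length of 6, a left pad of 2
// and a right pad of 1, the upper (padded) length is 6 + 2 + 1 = 9, the lower index is
// idx_up - 2, and an upper index maps to a valid lower index only when 2 <= idx_up < 2 + 6.
namespace pad_sketch {
constexpr bool is_valid(int idx_up) { return idx_up >= 2 && idx_up < 2 + 6; }
static_assert(!is_valid(1) && is_valid(2) && is_valid(7) && !is_valid(8), "padding window check");
} // namespace pad_sketch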
// LowerLengths: Sequence<...>
template <typename LowerLengths>
struct Merge
{
    static constexpr index_t nDimLow = LowerLengths::Size();
    static constexpr index_t nDimUp  = 1;

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return Sequence<reduce_on_sequence(
            LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
    }

    // emulate constexpr lambda
    template <typename PseudoLowStrides>
    struct lambda_CalculateLowerIndex
    {
        index_t& itmp;
        LowerIndex& idx_low;

        __host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
                                                                          LowerIndex& idx_low_)
            : itmp(itmp_), idx_low(idx_low_)
        {
        }

        template <typename IDim>
        __host__ __device__ constexpr void operator()(IDim idim) const
        {
            constexpr index_t stride = PseudoLowStrides::At(idim);
            idx_low(idim) = itmp / stride;
            itmp -= idx_low[idim] * stride;
        }
    };

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low;

        index_t itmp = idx_up[0];

        constexpr auto pseudo_low_strides =
            reverse_inclusive_scan_sequence(
                LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                .PushBack(Number<1>{});

        static_for<0, nDimLow - 1, 1>{}(
            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));

        idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];

        return idx_low;
    }

    // idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date.
    // If idx_up_diff is known at compile-time, many calculations can be optimized
    // away by the compiler.
    // This function assumes idx_low_old is not out-of-bounds.
    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff,
        const UpperIndex& /* idx_up_old */,
        const LowerIndex& idx_low_old)
    {
        // do nothing if idx_up_diff == 0
        if(idx_up_diff[0] == 0)
        {
            return make_zero_array<index_t, nDimLow>();
        }

        // CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
        // If idx_up_diff is known at compile-time, the calculation can
        // be done at compile-time. However, if idx_up_diff is only known
        // at run-time, then the calculation will also be computed at
        // run-time, and can be very expensive.
        LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);

        if(idx_up_diff[0] > 0)
        {
            bool carry = false;

            // do carry check in reversed order, starting from lowest dimension
            // don't check the highest dimension
            static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
                constexpr index_t i = nDimLow - 1 - ireverse;

                if(carry)
                {
                    ++idx_low_new(i);
                }

                carry = false;

                if(idx_low_new[i] >= LowerLengths::At(i))
                {
                    idx_low_new(i) -= LowerLengths::At(i);
                    carry = true;
                }
            });

            // highest dimension, no out-of-bound check
            if(carry)
            {
                ++idx_low_new(0);
            }
        }
        else if(idx_up_diff[0] < 0)
        {
            bool borrow = false;

            // do borrow check in reversed order, starting from lowest dimension
            // don't check the highest dimension
            static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
                constexpr index_t i = nDimLow - 1 - ireverse;

                if(borrow)
                {
                    --idx_low_new(i);
                }

                borrow = false;

                if(idx_low_new[i] < 0)
                {
                    idx_low_new(i) += LowerLengths::At(i);
                    borrow = true;
                }
            });

            // highest dimension, no out-of-bound check
            if(borrow)
            {
                --idx_low_new(0);
            }
        }

        return idx_low_new - idx_low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
    {
        return true;
    }
};
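The heart of Merge is a mixed-radix decode: the single upper index is divided by "pseudo strides" derived from the lower lengths, and the remainders become the lower multi-index. Below is a minimal standalone sketch of that decode in plain C++ (not the ck template machinery); the lengths and the sample index are made-up illustration values.

#include <array>
#include <cassert>

// Decode a flat (merged) index into a multi-index, mirroring Merge::CalculateLowerIndex:
// divide by pseudo strides {lengths[1]*lengths[2], lengths[2], 1} and keep the remainders.
std::array<int, 3> decode_merged_index(int flat, const std::array<int, 3>& lengths)
{
    std::array<int, 3> idx{};
    const int strides[3] = {lengths[1] * lengths[2], lengths[2], 1};
    for(int d = 0; d < 3; ++d)
    {
        idx[d] = flat / strides[d];
        flat -= idx[d] * strides[d];
    }
    return idx;
}

int main()
{
    // merging Sequence<2, 3, 4>: upper index 13 maps to lower index (1, 0, 1)
    assert((decode_merged_index(13, {2, 3, 4}) == std::array<int, 3>{1, 0, 1}));
    return 0;
}

This also shows why Merge is flagged as non-linear: a run-time decode needs integer divisions, which is exactly what CalculateLowerIndexDiff tries to avoid by reusing the old lower index and doing carry/borrow propagation instead.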
// UpperLengths: Sequence<...>
template <typename UpperLengths>
struct UnMerge
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpperLengths::Size();

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low{0};

        constexpr auto pseudo_up_strides =
            reverse_inclusive_scan_sequence(
                UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                .PushBack(Number<1>{});

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });

        return idx_low;
    }

    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff,
        const UpperIndex& /* idx_up_old */,
        const LowerIndex& /* idx_low_old */)
    {
        return CalculateLowerIndex(idx_up_diff);
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
    {
        return true;
    }
};
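UnMerge goes the other way: a multi-dimensional upper index is flattened into one lower index by an inner product with the pseudo strides of the upper lengths. Because this is a pure linear combination, the index-diff path can simply reuse CalculateLowerIndex on the difference. A small standalone sketch (plain C++, made-up lengths):

#include <array>
#include <cassert>

// Flatten a multi-index into one offset, mirroring UnMerge::CalculateLowerIndex:
// idx_low = sum_i idx_up[i] * pseudo_stride[i], row-major pseudo strides.
int encode_unmerged_index(const std::array<int, 3>& idx, const std::array<int, 3>& lengths)
{
    const int strides[3] = {lengths[1] * lengths[2], lengths[2], 1};
    int flat = 0;
    for(int d = 0; d < 3; ++d)
        flat += idx[d] * strides[d];
    return flat;
}

int main()
{
    // upper lengths Sequence<2, 3, 4>: (1, 0, 1) flattens back to 13
    assert(encode_unmerged_index({1, 0, 1}, {2, 3, 4}) == 13);
    return 0;
}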
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <typename UpperLengths, typename Coefficients>
struct Embed
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpperLengths::Size();

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ explicit constexpr Embed()
    {
        static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
                      "wrong! # of dimensions not consistent");
    }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low(Coefficients{}[nDimUp]);

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });

        return idx_low;
    }

    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff,
        const UpperIndex& /* idx_up_old */,
        const LowerIndex& /* idx_low_old */)
    {
        LowerIndex idx_low_diff{0};

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
    {
        return true;
    }
};
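Embed is an affine map: the lower index is the dot product of the upper index with the first nDimUp coefficients, plus the last coefficient as a constant term. This is the shape of index arithmetic used, for example, when an output pixel index and a filter-tap index are combined into an input-image index in implicit-GEMM convolution. A standalone sketch with made-up coefficients (not the actual kernel configuration):

#include <array>
#include <cassert>
#include <cstddef>

// Affine embedding, mirroring Embed::CalculateLowerIndex:
// idx_low = sum_i coeff[i] * idx_up[i] + coeff[N]  (last coefficient is the constant term)
template <std::size_t N>
int embed(const std::array<int, N>& idx_up, const std::array<int, N + 1>& coeff)
{
    int idx_low = coeff[N];
    for(std::size_t i = 0; i < N; ++i)
        idx_low += coeff[i] * idx_up[i];
    return idx_low;
}

int main()
{
    // e.g. a stride-2, dilation-1 mapping: input row = 2 * ho + 1 * y + 0
    assert((embed<2>({3, 2}, {2, 1, 0})) == 8);
    return 0;
}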
template <index_t LowerLength, index_t VectorSize>
struct Vectorize
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    __host__ __device__ constexpr Vectorize()
    {
        static_assert(VectorSize > 0 && LowerLength % VectorSize == 0,
                      "wrong! cannot evenly divide");
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return Sequence<LowerLength / VectorSize>{};
    }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return VectorSize * idx_up;
    }

    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff,
        const UpperIndex& /* idx_up_old */,
        const LowerIndex& /* idx_low_old */)
    {
        return VectorSize * idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
    {
        return true;
    }
};

} // namespace ck
#endif
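Vectorize is the simplest transform of the set: it views LowerLength scalars as LowerLength / VectorSize vectors, so the lower index is just VectorSize times the upper index, and the constructor asserts the division is exact. A trivial standalone sketch with made-up sizes:

#include <cassert>

// Upper index i addresses the vector starting at scalar offset VectorSize * i.
constexpr int vectorized_lower_index(int idx_up, int vector_size) { return vector_size * idx_up; }

int main()
{
    constexpr int lower_length = 16, vector_size = 4;
    static_assert(lower_length % vector_size == 0, "must divide evenly, as Vectorize asserts");
    assert(vectorized_lower_index(3, vector_size) == 12); // 4th vector starts at scalar 12
    return 0;
}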
composable_kernel/include/tensor_description/tensor_coordinate.hpp
View file @ b37cb71f
@@ -2,299 +2,248 @@
#define CK_TENSOR_COORDINATE_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
#include "tensor_descriptor.hpp"

namespace ck {
// A "tensor coordinate" is an opaque object that represents a "point of location" inside a tensor.
// At the bare minimum, the user should be able to query the following information from a tensor
// coordinate:
//   1. Tensor descriptor
//   2. Location, represented in the form of a multi-index
//   3. Location, represented in the form of the offset to the origin of the tensor
//   4. Whether the location is inside an invalid area, i.e. the padding area of an implicitly
//      padded tensor is considered invalid, because the padding area doesn't have any physical
//      memory allocation
// A tensor coordinate also provides the following functionality:
//   1. Given a step size in each dimension, update itself, or return a new tensor coordinate, so
//      the user can freely move the "point of location" inside the tensor

// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
template <typename TensorDesc>
struct TensorCoordinate;

// tensor coordinate for native tensor
template <typename NativeTensorDesc>
struct NativeTensorCoordinate
{
    using type             = NativeTensorCoordinate;
    using tensor_desc_type = NativeTensorDesc;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    __host__ __device__ constexpr NativeTensorCoordinate(Index idx)
        : mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx))
    {
    }

    template <typename... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
        : NativeTensorCoordinate(Index{xs...})
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
        : NativeTensorCoordinate(Index{Xs...})
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }

    __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }

    __host__ __device__ constexpr type operator+=(const Index& idx_diff)
    {
        // mIndex is updated here, but some (or all) of its entries may never be used
        // compiler should remove those entries as dead code
        mIndex += idx_diff;

        mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator-=(const Index& idx_diff)
    {
        // mIndex is updated here, but some (or all) of its entries may never be used
        // compiler should remove those entries as dead code
        mIndex -= idx_diff;

        mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator+(const Index& idx_diff) const
    {
        type coord = *this;
        coord += idx_diff;
        return coord;
    }

    __host__ __device__ constexpr type operator-(const Index& idx_diff) const
    {
        type coord = *this;
        coord -= idx_diff;
        return coord;
    }

    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
    {
        return tensor_desc_type::CalculateOffsetDiff(idx_diff);
    }

    __host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }

    private:
    // mIndex may be saved and updated; however, the value of some (or all) of its entries may
    // never be used. The compiler should be able to remove these entries, as well as their
    // calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    Index mIndex;
    index_t mOffset;
};

// tensor coordinate for transformed tensor
template <typename TransformedTensorDesc>
struct TransformedTensorCoordinate
{
    using tensor_desc_type = TransformedTensorDesc;

    using LowerCoord =
        typename TensorCoordinate<decltype(tensor_desc_type::GetLowerTensorDescriptor())>::type;
    using UpperCoord = TransformedTensorCoordinate;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();

    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx)
        : mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)}
    {
    }

    template <typename... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
        : TransformedTensorCoordinate(UpperIndex{xs...})
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
        : TransformedTensorCoordinate(UpperIndex{Xs...})
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; }

    __host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }

    __host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }

    __host__ __device__ constexpr const index_t& GetOffset() const
    {
        return GetLowerCoordinate().GetOffset();
    }

    __host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff)
    {
        // For transformation of multi-index difference, not all transformation functions need to
        // know the old lower-index or the old upper-index. We pass both of them to the
        // transformation function. The transformation function itself decides to use them or not.
        mCoordLow += tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used
        // compiler should remove those entries as dead code
        mIndexUp += idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff)
    {
        mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used
        // compiler should remove those entries as dead code
        mIndexUp -= idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const
    {
        UpperCoord coord_up = *this;
        coord_up += idx_up_diff;
        return coord_up;
    }

    __host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const
    {
        UpperCoord coord_up = *this;
        coord_up -= idx_up_diff;
        return coord_up;
    }

    // Calculate offset diff without updating the tensor coordinate.
    // If idx_up_diff is known at compile time, and has only non-zero entries on linear dimensions,
    // then all calculation can be done at compile-time.
    // TODO: this function is not compiled to the expected ISA
    __host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
    {
        // For transformation of multi-index difference, not all transformation functions need to
        // know the old lower-index or the old upper-index. We pass both of them to the
        // transformation function. The transformation function itself decides to use them or not.
        const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
    }

    __host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
    {
        return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
               mCoordLow.IsUpperIndexMappedToValidOffset();
    }

    private:
    // mIndexUp may be calculated and updated; however, the value of some (or all) of its entries
    // may never be used. The compiler should be able to remove these entries, as well as their
    // calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    UpperIndex mIndexUp;
    LowerCoord mCoordLow;
};

template <typename TensorDesc>
struct TensorCoordinate
{
    private:
    template <typename... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
    {
        return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
            make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
    }

    template <typename... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
    {
        return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
            make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
    }

    public:
    using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};

} // namespace ck
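The reason the coordinate caches mOffset (and, for transformed tensors, the whole lower coordinate) is that stepping only requires an offset difference, never a full recomputation of the offset from scratch. Here is a stripped-down analogue of that pattern for a plain strided tensor, in standalone C++ with made-up lengths and strides:

#include <array>
#include <cassert>

// Keep the multi-index and a cached linear offset, and update both incrementally
// when the coordinate moves (the idea behind NativeTensorCoordinate).
struct Coord2D
{
    std::array<int, 2> idx;
    std::array<int, 2> strides;
    int offset;

    Coord2D(std::array<int, 2> i, std::array<int, 2> s)
        : idx(i), strides(s), offset(i[0] * s[0] + i[1] * s[1])
    {
    }

    Coord2D& operator+=(const std::array<int, 2>& diff)
    {
        for(int d = 0; d < 2; ++d)
        {
            idx[d] += diff[d];
            offset += diff[d] * strides[d]; // offset diff only, no full recompute
        }
        return *this;
    }
};

int main()
{
    Coord2D c({1, 2}, {8, 1}); // 2-d tensor with strides {8, 1}
    assert(c.offset == 10);
    c += std::array<int, 2>{0, 3}; // move 3 along the fastest dimension
    assert(c.offset == 13);
    return 0;
}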
composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp
0 → 100644
View file @ b37cb71f

#ifndef CK_TENSOR_COORDINATE_DEPRECATED_HPP
#define CK_TENSOR_COORDINATE_DEPRECATED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"

namespace ck {

// TensorDesc is ConstantTensorDescriptor_deprecated
template <class TensorDesc>
struct NormalTensorCoordinate_deprecated
{
    using type             = NormalTensorCoordinate_deprecated;
    using tensor_desc_type = TensorDesc;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();

    __host__ __device__ constexpr NormalTensorCoordinate_deprecated(
        Array<index_t, nDim> tensor_index)
        : mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
    {
    }

    template <class... Xs>
    __host__ __device__ constexpr NormalTensorCoordinate_deprecated(Xs... xs)
        : NormalTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr NormalTensorCoordinate_deprecated(Sequence<Xs...>)
        : NormalTensorCoordinate_deprecated(Array<index_t, nDim>{Xs...})
    {
    }

    __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

    // T is Array or Sequence
    template <class T>
    __host__ __device__ type operator+=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);

        return *this;
    }

    template <class T>
    __host__ __device__ type operator-=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);

        return *this;
    }

    template <class T>
    __host__ __device__ constexpr type operator+(T step_sizes) const
    {
        type coord = *this;
        coord += step_sizes;
        return coord;
    }

    template <class T>
    __host__ __device__ constexpr type operator-(T step_sizes) const
    {
        type coord = *this;
        coord -= step_sizes;
        return coord;
    }

    // Reposition the point of origin, and return the compensated offset.
    // This is a hack to reduce index calculation while looping over a tensor whose origin is this
    // TensorCoordinate. It does so by spitting out the run-time offset to the pointer (to the
    // tensor data) held by this TensorCoordinate, so the caller can add the offset into the
    // run-time pointer of the data; then only 1 run-time variable (the updated pointer) is needed,
    // instead of 2 run-time variables (the old pointer and this offset).
    // TODO: after introducing the concept of a "run-time tensor view", which contains the run-time
    // pointer to the data, always keep track of the pointer instead of both the offset and the
    // pointer. This also brings an additional benefit: we don't need to worry that the offset
    // might underflow (because the offset is an unsigned integer) when updating it.
    __host__ __device__ constexpr index_t RepositionOrigin()
    {
        index_t offset_diff = mOffset;
        mOffset             = 0;
        return offset_diff;
    }

    private:
    index_t mOffset;
};

// TensorDesc is ConstantMergedTensorDescriptor_deprecated
template <class TensorDesc>
struct MergedTensorCoordinate_deprecated
{
    using type             = MergedTensorCoordinate_deprecated;
    using tensor_desc_type = TensorDesc;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
    static constexpr index_t nOriginalDim =
        tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();

    __host__ __device__ constexpr MergedTensorCoordinate_deprecated(
        Array<index_t, nDim> tensor_index)
        : mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
    {
        // partial offset on each dimension
        static_for<0, nDim, 1>{}([&](auto idim) {
            constexpr auto partial_original_dims =
                tensor_desc_type::GetContainedOriginalDimensions(idim);

            constexpr auto partial_original_desc =
                tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);

            mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
                extract_array(mOriginalIndex, partial_original_dims));
        });

        // complete offset
        mOffset =
            accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
    }

    template <class... Xs>
    __host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... xs)
        : MergedTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
    {
    }

    __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

    template <class IDim, class T, bool PositiveDirection>
    __host__ __device__ void
    MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
    {
        constexpr auto idim = idim_;

        // if step_size is known at compile time
        static_if<is_static<T>::value>{}(
            [&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });

        // update original index
        static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
            constexpr auto partial_original_dims =
                tensor_desc_type::GetContainedOriginalDimensions(idim);

            constexpr index_t ndim_partial_original = partial_original_dims.GetSize();

            constexpr auto partial_original_desc =
                tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);

            const auto partial_original_step_sizes =
                partial_original_desc.GetMultiIndexFrom1dIndex(step_size);

            // update partial original multi-id
            auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);

            static_if<PositiveDirection>{}([&](auto) {
                partial_original_id += partial_original_step_sizes;

                bool carry = false;

                // do carry check in reversed order, starting from lowest dimension
                // don't check the highest dimension
                static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
                    constexpr index_t i = ndim_partial_original - 1 - IReverse;

                    if(carry)
                    {
                        ++partial_original_id(i);
                    }

                    carry = false;

                    if(partial_original_id[i] >= partial_original_desc.GetLength(i))
                    {
                        partial_original_id(i) -= partial_original_desc.GetLength(i);
                        carry = true;
                    }
                });

                // highest dimension
                if(carry)
                {
                    ++partial_original_id(0);
                }
            }).Else([&](auto) {
                // shift up multi-id to avoid unsigned integer underflow during intermediate
                // calculations. After the shift, should have new_multi_id[...] >= 1
                partial_original_id +=
                    partial_original_desc.GetLengths() - partial_original_step_sizes;

                bool borrow = false;

                // do borrow check in reversed order, starting from lowest dimension
                // don't check the highest dimension
                static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
                    constexpr index_t i = ndim_partial_original - 1 - IReverse;

                    if(borrow)
                    {
                        --partial_original_id(i);
                    }

                    borrow = false;

                    if(partial_original_id[i] < partial_original_desc.GetLength(i))
                    {
                        partial_original_id(i) += partial_original_desc.GetLength(i);
                        borrow = true;
                    }
                });

                // highest dimension
                if(borrow)
                {
                    --partial_original_id(0);
                }

                // shift back down multi-id
                // here, should have new_multi_id[...] >= GetLengths()
                partial_original_id = partial_original_id - partial_original_desc.GetLengths();
            });

            // update "mOriginalIndex"
            static_for<0, ndim_partial_original, 1>{}([&](auto I) {
                constexpr auto idim_original = partial_original_dims[I];

                mOriginalIndex(idim_original) = partial_original_id[I];
            });

            // calculate new partial offset on this merged dimension
            const index_t old_partial_offset = mPartialOffsets[idim];

            mPartialOffsets(idim) =
                partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);

            // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
            mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
        }).Else([&](auto fwd) {
            static_if<PositiveDirection>{}([&](auto) {
                mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
            }).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
        });
    }

    // T is Array or Sequence
    template <class T>
    __host__ __device__ type operator+=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        static_for<0, nDim, 1>{}([&](auto idim) {
            // compiler should remove dead code path, because step_sizes is known at compile time
            if(step_sizes[idim] != 0)
            {
                this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
            }
        });

        return *this;
    }

    template <class T>
    __host__ __device__ type operator-=(T step_sizes)
    {
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        static_for<0, nDim, 1>{}([&](auto idim) {
            // compiler should remove dead code path, because step_sizes is known at compile time
            if(step_sizes[idim] != 0)
            {
                this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
            }
        });

        return *this;
    }

    template <class T>
    __host__ __device__ constexpr type operator+(T step_sizes) const
    {
        type coord = *this;
        coord += step_sizes;
        return coord;
    }

    template <class T>
    __host__ __device__ constexpr type operator-(T step_sizes) const
    {
        type coord = *this;
        coord -= step_sizes;
        return coord;
    }

    __host__ __device__ static constexpr index_t RepositionOrigin() { return 0; }

    private:
    // Allocate register memory for all merged dimensions and normal dimensions.
    // However, only those merged dimensions whose index will be involved in arithmetic
    // after the construction of this TensorCoordinate (e.g. when the user moves a slicing
    // window on the merged dimension) will use this register memory.
    // Let's hope the compiler will optimize away the register memory allocated for normal
    // dimensions, and for those merged dimensions that will never be involved in index
    // arithmetic after construction of the TensorCoordinate.
    // TODO: refactor TensorCoordinate after introducing the concept of "dimensions",
    // and simplify the implementation of ConstantMergedTensorDescriptor_deprecated, so we don't
    // need to count on the compiler to optimize away this register memory for us
    Array<index_t, nOriginalDim> mOriginalIndex;
    Array<index_t, nDim> mPartialOffsets;

    // complete offset
    index_t mOffset;
};

template <class TensorDesc>
struct TensorCoordinate_deprecated
{
    private:
    template <class... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
    {
        return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
    }

    template <class... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
    {
        return MergedTensorCoordinate_deprecated<
            ConstantMergedTensorDescriptor_deprecated<Ts...>>();
    }

    public:
    using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};

} // namespace ck
#endif
composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp
0 → 100644
View file @ b37cb71f

#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
#define CK_TENSOR_COORDINATE_HELPER_HPP

#include "tensor_coordinate.hpp"

namespace ck {

template <typename TensorDesc>
__host__ __device__ constexpr auto
make_tensor_coordinate(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
{
    return typename TensorCoordinate<TensorDesc>::type(idx);
}

} // namespace ck
#endif
composable_kernel/include/tensor_description/tensor_descriptor.hpp
0 → 100644
View file @ b37cb71f

#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"

namespace ck {

// tensor descriptor for "native tensor"
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
    using type = NativeTensorDescriptor;

    static constexpr index_t nDim = sizeof...(NativeDimensions);

    static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);

    using Index = MultiIndex<nDim>;

    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        return mDimensions.At(Number<IDim>{}).GetLength();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
    {
        return mDimensions.At(Number<IDim>{}).GetStride();
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
    {
        return Sequence<GetLength(Number<IDims>{})...>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
    {
        return Sequence<GetStride(Number<IDims>{})...>{};
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
    {
        return GetLengths(Sequence<IDim, IDims...>{});
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
    {
        return GetStrides(Sequence<IDim, IDims...>{});
    }

    __host__ __device__ static constexpr auto GetLengths()
    {
        return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
    }

    __host__ __device__ static constexpr auto GetStrides()
    {
        return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
    }

    __host__ __device__ static constexpr index_t GetElementSpace()
    {
        return reduce_on_sequence(
            (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
    }

    // TODO: this cannot return constexpr because of the use of a lambda
    __host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
    {
        index_t offset = 0;

        static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });

        return offset;
    }

    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
    {
        index_t offset_diff = 0;

        static_for<0, nDim, 1>{}(
            [&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });

        return offset_diff;
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
    {
        return true;
    }

    __host__ __device__ static constexpr auto GetLinearDimensionMask()
    {
        return typename uniform_sequence_gen<nDim, 1>::type{};
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensionMask()
    {
        return typename uniform_sequence_gen<nDim, 0>::type{};
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }

    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
    {
        return Tuple<>{};
    }

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidOffset(const Index& /* idx */)
    {
        return true;
    }
};

// Tensor descriptor for "transformed tensor"
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
          typename Transforms,          // Tuple<MultIndexTransforms...>
          typename LowDimensionIds,     // Tuple<Sequence<...>>
          typename UpDimensionIds>      // Tuple<Sequence<...>>
struct TransformedTensorDescriptor
{
    using type = TransformedTensorDescriptor;

    static constexpr index_t nTransform = Transforms::Size();

    struct lambda_merge_sequences
    {
        template <typename... Seqs>
        __host__ __device__ constexpr auto operator()(Seqs... seqs) const
        {
            return merge_sequences(seqs...);
        }
    };

    __host__ __device__ static constexpr auto GetNumOfLowerDimension()
    {
        // Here, we assume all lower-dimensions are active
        // TODO: sanity-check all lower-dimensions are indeed active
        using duplicated_low_active_dims =
            decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));

        using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
                                                              math::less<index_t>,
                                                              math::equal<index_t>>::type;

        return low_active_dims::Size();
    }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
    {
        using duplicated_up_active_dims =
            decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));

        using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
                                                             math::less<index_t>,
                                                             math::equal<index_t>>::type;

        return up_active_dims::Size();
    }

    static constexpr index_t nDimUp  = GetNumOfUpperDimension();
    static constexpr index_t nDimLow = GetNumOfLowerDimension();

    using UpperIndex = MultiIndex<nDimUp>;
    using LowerIndex = MultiIndex<nDimLow>;

    __host__ __device__ constexpr TransformedTensorDescriptor()
    {
        static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
                          nTransform == UpDimensionIds::Size(),
                      "wrong! # of transformations not the same");

        // sanity check:
        //   LowDimensionIds should include all low-dimensions,
        //   UpDimensionIds should include all up-dimensions
        using mingled_up_dimension_ids =
            decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));

        using sorted_up_dimension_ids =
            typename sequence_sort<mingled_up_dimension_ids, math::less<index_t>>::type;

        static_assert(sorted_up_dimension_ids::Size() == nDimUp &&
                          is_valid_sequence_map<sorted_up_dimension_ids>{},
                      "wrong! UpDimensionIds is not configured correctly");

        using mingled_low_dimension_ids =
            decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));

        using sorted_low_dimension_ids =
            typename sequence_sort<mingled_low_dimension_ids, math::less<index_t>>::type;

        static_assert(sorted_low_dimension_ids::Size() == nDimLow &&
                          is_valid_sequence_map<sorted_low_dimension_ids>{},
                      "wrong! LowDimensionIds is not configured correctly");

        // TODO: sanity check: while an up-dimension could be associated with multiple
        // transformations, a low-dimension should be associated with only one transformation

        // TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
        // of lower-tensor-descriptor
    }

    __host__ __device__ static constexpr auto GetNumOfDimension()
    {
        return GetNumOfUpperDimension();
    }

    __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
    {
        return LowTensorDescriptor{};
    }

    struct lambda_GetUpperLengths
    {
        template <typename Transform>
        __host__ __device__ constexpr auto operator()(const Transform& tran) const
        {
            return tran.GetUpperLengths();
        }
    };

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        constexpr auto tuple_of_up_lengths =
            transform_tuples(lambda_GetUpperLengths{}, Transforms{});

        constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);

        constexpr auto mingled_up_dimension_ids =
            unpack(lambda_merge_sequences{}, UpDimensionIds{});

        // TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
        // TODO: sanity-check mingled_up_lengths have no conflicting upper-length

        // sort by upper-dimension-ids
        using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
                                                           math::less<index_t>,
                                                           math::equal<index_t>>;

        // sanity-check: sorted upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
        static_assert(is_same<typename sort_up_dimension_ids::type,
                              typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
                      "wrong! UpDimensionIds is not configured correctly");

        constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};

        constexpr auto sorted_up_lengths =
            pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);

        return sorted_up_lengths;
    }

    __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        return GetLengths()[IDim];
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
    {
        return Sequence<GetLength(Number<IDims>{})...>{};
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
    {
        return GetLengths(Sequence<IDim, IDims...>{});
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
    }

    __host__ __device__ static constexpr index_t GetElementSpace()
    {
        // TODO: Is this the correct definition for transformed tensor?
        return GetLowerTensorDescriptor().GetElementSpace();
    }

    // TODO: right now the return value is not constexpr because of the use of a non-constexpr
    // lambda
    __host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
            auto idx_low_part      = pick_array_element(idx_low, LowDimensionIds{}.At(itran));

            // this assumes each lower (single) index is only associated with one transformation,
            // which is required for index transformation, and has been checked during the
            // constructor of TransformedTensorDescriptor
            idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
        });

        return idx_low;
    }

    // TODO: right now the return value is not constexpr because of the use of a non-constexpr
    // lambda
    __host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
    {
        LowerIndex idx_low_diff;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_diff_part =
                pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran));

            const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran));

            const auto idx_low_old_part =
                pick_array_element(idx_low_old, LowDimensionIds{}.At(itran));

            auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran));

            // this assumes each lower (single) index is associated with only one transformation,
            // which is required for index transformation, and has been checked during the
            // constructor of TransformedTensorDescriptor
            idx_low_diff_part = tran.CalculateLowerIndexDiff(
                to_array(idx_up_diff_part), to_array(idx_up_old_part), to_array(idx_low_old_part));
        });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
    {
        return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
    }

    struct lambda_sequence_logical_and
    {
        template <typename... Seqs>
        __host__ __device__ constexpr auto operator()(Seqs...) const
        {
            return typename sequence_reduce<logical_and<index_t>, Seqs...>::type{};
        }
    };

    template <typename T>
    struct lambda_is_true
    {
        __host__ __device__ constexpr auto operator()(const T& x) const
        {
            // TODO: remove static_cast once Sequence can take bool as entries
            return static_cast<bool>(x) == true;
        }
    };

    struct lambda_get_linear_dimension_mask_of_single_tranform
    {
        // check only one transform at a time
        template <typename Transform, typename LowDimensionId, typename UpDimensionId>
        __host__ __device__ constexpr auto
        operator()(Transform, LowDimensionId, UpDimensionId) const
        {
            // judge if transformation is linear
            constexpr bool is_linear_transform = Transform::IsLinearTransform();

            // judge if all lower dimensions are linear
            constexpr bool are_all_low_dim_linear = sequence_all_of(
                pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(),
                                              LowDimensionId{}),
                lambda_is_true<index_t>{});

            // create linear mask for upper dimensions
            constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;

            constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
                typename uniform_sequence_gen<nDimUp, 1>::type{},
                typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
                UpDimensionId{});

            return mask_of_up_linear_dims;
        }
    };

    // TODO: this is a hack, transform_tuples() doesn't compile, would complain about constexpr
    template <typename F, typename X, typename Y, typename Z, index_t... Is>
    __host__ __device__ static constexpr auto
    dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
    {
        return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
    }

    __host__ __device__ static constexpr auto GetLinearDimensionMask()
    {
#if 0
        // create tuple of linear dimension masks, for all transformations
        // TODO: this doesn't compile, because transform_tuples() complains about constexpr
        constexpr auto tuple_of_linear_dimension_mask =
            transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
                             Transforms{},
                             LowDimensionIds{},
                             UpDimensionIds{});
#else
        // create tuple of linear dimension masks, for all transformations
        // TODO: this is a hack
        constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
            lambda_get_linear_dimension_mask_of_single_tranform{},
            Transforms{},
            LowDimensionIds{},
            UpDimensionIds{},
            typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
#endif

        // reduce tuple of masks into one mask
        constexpr auto linear_dimension_mask =
            unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask);

        return linear_dimension_mask;
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensionMask()
    {
        return GetLinearDimensionMask().Transform(logical_not<index_t>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
    {
        return GetLinearDimensionMask().At(Number<IDim>{});
    }

    __host__ __device__ static constexpr auto GetLinearDimensions()
    {
        constexpr auto linear_dimension_mask = GetLinearDimensionMask();

        return pick_sequence_elements_by_mask(
            typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensions()
    {
        constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask();

        return pick_sequence_elements_by_mask(
            typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
    }

#if 0
    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
    {
        // TODO: not implemented
    }
#endif

    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
    {
        bool flag = true;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));

            flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
        });

        return flag;
    }

    // Whenever this function is called, it will call CalculateLowerIndex() recursively.
    // If you have created a tensor coordinate already, instead of calling this function,
    // you should call TensorCoordinate::IsUpperIndexMappedToValidOffset(), which would
    // be less expensive.
    __host__ __device__ static constexpr bool
    IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
    {
        return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
               GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
                   CalculateLowerIndex(idx_up));
    }
};

} // namespace ck
#endif
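The two descriptor kinds split the work cleanly: a native descriptor computes its offset as the dot product of the index with the strides, while a transformed descriptor first lowers the upper index through its transforms and then delegates to the lower descriptor. A plain-C++ sketch of that two-level lowering for a single Merge-like transform sitting on top of a strided layout (all lengths, strides, and indices below are made up):

#include <array>
#include <cassert>

// Native layer: offset = sum_i idx[i] * stride[i]  (cf. NativeTensorDescriptor::CalculateOffset)
int native_offset(const std::array<int, 2>& idx, const std::array<int, 2>& strides)
{
    return idx[0] * strides[0] + idx[1] * strides[1];
}

// Transformed layer: lower a 1-d "merged" upper index to the 2-d native index,
// then reuse the native offset (cf. TransformedTensorDescriptor::CalculateOffset).
int transformed_offset(int idx_up,
                       const std::array<int, 2>& lengths,
                       const std::array<int, 2>& strides)
{
    const std::array<int, 2> idx_low = {idx_up / lengths[1], idx_up % lengths[1]};
    return native_offset(idx_low, strides);
}

int main()
{
    // a 3x4 tensor stored with a padded row stride of 5
    const std::array<int, 2> lengths = {3, 4}, strides = {5, 1};
    assert(native_offset({2, 1}, strides) == 11);
    assert(transformed_offset(9, lengths, strides) == 11); // 9 -> (2, 1) -> 11
    return 0;
}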
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
0 → 100644
View file @
b37cb71f
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace
ck
{
template
<
typename
Lengths
>
__host__
__device__
constexpr
auto
calculate_tensor_strides_packed
(
Lengths
)
{
return
reverse_inclusive_scan_sequence
(
Lengths
{}.
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
.
PushBack
(
Number
<
1
>
{});
}
template
<
typename
Lengths
,
index_t
Align
>
__host__
__device__
constexpr
auto
calculate_tensor_strides_aligned
(
Lengths
,
Number
<
Align
>
)
{
constexpr
index_t
L_back_align
=
Align
*
math
::
integer_divide_ceiler
<
index_t
>
{}(
Lengths
{}.
Back
(),
Align
);
return
calculate_tensor_strides_packed
(
Lengths
{}.
Modify
(
Number
<
Lengths
{}.
GetSize
()
-
1
>
{},
Number
<
L_back_align
>
{}));
}
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
__host__
__device__
constexpr
auto
make_native_tensor_descriptor
(
Sequence
<
Lengths
...
>
,
Sequence
<
Strides
...
>
)
{
return
NativeTensorDescriptor
<
NativeDimension
<
Lengths
,
Strides
>
...
>
{};
}
template
<
typename
Lengths
>
__host__
__device__
constexpr
auto
make_native_tensor_descriptor_packed
(
Lengths
)
{
constexpr
auto
strides
=
calculate_tensor_strides_packed
(
Lengths
{});
return
make_native_tensor_descriptor
(
Lengths
{},
strides
);
}
template
<
typename
Lengths
,
index_t
Align
>
__host__
__device__
constexpr
auto
make_native_tensor_descriptor_aligned
(
Lengths
,
Number
<
Align
>
)
{
constexpr
auto
strides
=
calculate_tensor_strides_aligned
(
Lengths
{},
Number
<
Align
>
{});
return
make_native_tensor_descriptor
(
Lengths
{},
strides
);
}
template
<
typename
LowTensorDescriptor
,
typename
Transforms
,
typename
LowDimensionIds
,
typename
UpDimensionIds
>
__host__
__device__
constexpr
auto
transform_tensor_descriptor
(
LowTensorDescriptor
,
Transforms
,
LowDimensionIds
,
UpDimensionIds
)
{
return
TransformedTensorDescriptor
<
LowTensorDescriptor
,
Transforms
,
LowDimensionIds
,
UpDimensionIds
>
{};
}
template
<
typename
LowerTensorDescriptor
,
index_t
...
LowerLengths
,
index_t
...
LowerDimensionIds
,
index_t
...
UpperDimensionIds
>
__host__
__device__
constexpr
auto
reorder_transformed_tensor_descriptor_impl
(
LowerTensorDescriptor
,
Sequence
<
LowerLengths
...
>
,
Sequence
<
LowerDimensionIds
...
>
,
Sequence
<
UpperDimensionIds
...
>
)
{
return
TransformedTensorDescriptor
<
LowerTensorDescriptor
,
Tuple
<
PassThrough
<
LowerLengths
>
...
>
,
Tuple
<
Sequence
<
LowerDimensionIds
>
...
>
,
Tuple
<
Sequence
<
UpperDimensionIds
>
...
>>
{};
}
// reorder a NativeTensorDescriptor
template
<
typename
...
Ts
,
typename
MapLower2Upper
>
__host__
__device__
constexpr
auto
reorder_tensor_descriptor_given_lower2upper
(
NativeTensorDescriptor
<
Ts
...
>
,
MapLower2Upper
)
{
static_assert
(
is_valid_sequence_map
<
MapLower2Upper
>
{},
"wrong! MapLower2Upper is not a valid map"
);
constexpr
auto
old_desc
=
NativeTensorDescriptor
<
Ts
...
>
{};
static_assert
(
old_desc
.
GetNumOfDimension
()
==
MapLower2Upper
::
Size
(),
"wrong!"
);
constexpr
auto
new_lengths
=
old_desc
.
GetLengths
().
ReorderGivenOld2New
(
MapLower2Upper
{});
constexpr
auto
new_strides
=
old_desc
.
GetStrides
().
ReorderGivenOld2New
(
MapLower2Upper
{});
return
make_native_tensor_descriptor
(
new_lengths
,
new_strides
);
}
// reorder a TransformedTensorDescriptor
template
<
typename
...
Ts
,
typename
MapLower2Upper
>
__host__
__device__
constexpr
auto
reorder_tensor_descriptor_given_lower2upper
(
TransformedTensorDescriptor
<
Ts
...
>
,
MapLower2Upper
)
{
static_assert
(
is_valid_sequence_map
<
MapLower2Upper
>
{},
"wrong! MapLower2Upper is not a valid map"
);
constexpr
auto
low_desc
=
TransformedTensorDescriptor
<
Ts
...
>
{};
static_assert
(
low_desc
.
GetNumOfDimension
()
==
MapLower2Upper
::
Size
(),
"wrong!"
);
return
reorder_transformed_tensor_descriptor_impl
(
low_desc
,
low_desc
.
GetLengths
(),
typename
arithmetic_sequence_gen
<
0
,
low_desc
.
GetNumOfDimension
(),
1
>::
type
{},
MapLower2Upper
{});
}
template
<
typename
LowerTensorDescriptor
,
typename
MapUpper2Lower
>
__host__
__device__
constexpr
auto
reorder_tensor_descriptor_given_upper2lower
(
LowerTensorDescriptor
,
MapUpper2Lower
)
{
return
reorder_tensor_descriptor_given_lower2upper
(
LowerTensorDescriptor
{},
typename
sequence_map_inverse
<
MapUpper2Lower
>::
type
{});
}
template
<
typename
Lengths
,
typename
Strides
>
__host__
__device__
constexpr
bool
are_dimensions_unfoldable
(
Lengths
,
Strides
)
{
static_assert
(
Lengths
::
Size
()
==
Strides
::
Size
(),
"wrong!"
);
bool
flag
=
true
;
for
(
index_t
i
=
0
;
i
<
Lengths
::
Size
()
-
1
;
++
i
)
{
flag
=
flag
&&
Strides
::
At
(
i
)
==
Strides
::
At
(
i
+
1
)
*
Lengths
::
At
(
i
+
1
);
}
return
flag
;
}
// unfold only supports NativeTensorDescriptor, for now
template <index_t FirstUnfoldDim, index_t LastUnfoldDim, typename... Ts>
__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor<Ts...> desc,
                                                            Number<FirstUnfoldDim>,
                                                            Number<LastUnfoldDim>)
{
    constexpr index_t nDim = desc.GetNumOfDimension();

    static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim,
                  "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

    // left, middle and right dimension ids
    constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
    constexpr auto middle =
        typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
    constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};

    // sanity check: the middle dimensions must be unfoldable
    static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
                  "wrong! not unfoldable");

    // unfolded length and stride
    constexpr index_t unfold_length =
        reduce_on_sequence(desc.GetLengths(middle), math::multiplies<index_t>{}, Number<1>{});

    constexpr index_t unfold_stride = desc.GetStride(Number<LastUnfoldDim>{});

    // new lengths and strides
    constexpr auto new_lengths =
        desc.GetLengths(left).PushBack(Number<unfold_length>{}).PushBack(desc.GetLengths(right));

    constexpr auto new_strides =
        desc.GetStrides(left).PushBack(Number<unfold_stride>{}).PushBack(desc.GetStrides(right));

    return make_native_tensor_descriptor(new_lengths, new_strides);
}
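A hypothetical call (the KCYX shape is assumed): unfolding the C, Y, X dimensions of a packed weight descriptor into a single dimension, the usual way convolution weights are flattened into one GEMM dimension.

// Hypothetical: wei_kcyx_desc is packed, so dimensions 1..3 pass the unfoldability check.
constexpr auto wei_kcyx_desc  = make_native_tensor_descriptor_packed(Sequence<256, 128, 3, 3>{});
constexpr auto wei_k_cyx_desc = unfold_tensor_descriptor(wei_kcyx_desc, Number<1>{}, Number<3>{});
// resulting lengths: {256, 128 * 3 * 3} = {256, 1152}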
// a cluster maps a 1-d index to a N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
    static constexpr index_t nDim = Lengths::Size();

    static constexpr auto mDesc = transform_tensor_descriptor(
        make_native_tensor_descriptor_packed(Lengths{}),
        make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
        make_tuple(ArrangeOrder{}),
        make_tuple(Sequence<0>{}));

    __host__ __device__ constexpr ClusterDescriptor()
    {
        static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
                      "wrong! size not the same");

        static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
    }

    __host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }

    __host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
    {
        return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
    }
};
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
    Lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
    return ClusterDescriptor<Lengths, decltype(order)>{};
}

} // namespace ck
#endif
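How a blockwise copy would typically use the cluster descriptor (a sketch; the 4 x 64 shape and the resulting mapping are assumptions): each thread turns its flat id within the block into an N-d coordinate of the thread cluster, with the optional ArrangeOrder choosing which dimension varies fastest.

// Hypothetical: a 4 x 64 thread cluster with the default arrange order,
// so the last dimension varies fastest; a flat id of 70 maps to cluster index {1, 6}.
constexpr auto thread_cluster_desc = make_cluster_descriptor(Sequence<4, 64>{});
// inside a kernel:
//     const auto thread_cluster_id = thread_cluster_desc.CalculateClusterIndex(thread_id);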
composable_kernel/include/tensor_operation/blockwise_gemm.hpp
@@ -5,25 +5,23 @@
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
#endif
namespace
ck
{
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
// blockwise GEMM: C += transpose(A) * B
// A and B are visable to the whole block, C is distributed among each thread
// If following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template
<
index_t
BlockSize
,
index_t
EPack
,
class
BlockMatrixA
,
class
BlockMatrixB
,
class
ThreadMatrixC
,
typename
BlockMatrixA
,
typename
BlockMatrixB
,
typename
ThreadMatrixC
,
index_t
MPerThreadSubC
,
index_t
NPerThreadSubC
,
index_t
MLevel0Cluster
,
index_t
NLevel0Cluster
,
index_t
MLevel1Cluster
,
index_t
NLevel1Cluster
,
index_t
MLevel0
Thread
Cluster
,
index_t
NLevel0
Thread
Cluster
,
index_t
MLevel1
Thread
Cluster
,
index_t
NLevel1
Thread
Cluster
,
index_t
KPerThreadLoop
,
index_t
DataPerReadA
,
index_t
DataPerReadB
>
...
...
@@ -40,8 +38,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
     {
-        constexpr index_t ThreadPerLevel1Cluster =
-            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
+        constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
+                                                   MLevel1ThreadCluster * NLevel1ThreadCluster;

         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
@@ -51,8 +49,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
         constexpr index_t N = BlockMatrixB::NCol();

-        static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
-                          N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
+        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
+                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                       "wrong! Cannot evenly divide work among\n");

         static_assert(
@@ -70,26 +68,28 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
         constexpr index_t N = BlockMatrixB::NCol();

-        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
-        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
+        constexpr index_t MRepeat =
+            M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
+        constexpr index_t NRepeat =
+            N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);

         return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
     }

     __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
     {
-        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
+        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;

         index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
-        index_t level1_m_id = level1_id / NLevel1Cluster;
-        index_t level1_n_id = level1_id % NLevel1Cluster;
+        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
+        index_t level1_n_id = level1_id % NLevel1ThreadCluster;

         index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
-        index_t level0_m_id = level0_id / NLevel0Cluster;
-        index_t level0_n_id = level0_id % NLevel0Cluster;
+        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
+        index_t level0_n_id = level0_id % NLevel0ThreadCluster;

-        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
-        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
+        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;

         return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
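The two-level decomposition in GetBeginOfThreadMatrixC is easier to follow with concrete numbers. A standalone host-side sketch that reproduces the same arithmetic (the cluster sizes are made up, not taken from any kernel configuration in this commit):

#include <cstdio>

int main()
{
    // Assumed example config: 2x2 level-0 thread cluster, 4x4 level-1 thread cluster,
    // 4x4 C elements per thread sub-tile -> 2*2*4*4 = 64 threads per block.
    const int MPerThreadSubC       = 4, NPerThreadSubC       = 4;
    const int MLevel0ThreadCluster = 2, NLevel0ThreadCluster = 2;
    const int NLevel1ThreadCluster = 4; // MLevel1ThreadCluster would be 4 as well

    const int ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; // 4

    for(int thread_id = 0; thread_id < 8; ++thread_id)
    {
        const int level1_id   = thread_id / ThreadPerLevel0Cluster;
        const int level1_m_id = level1_id / NLevel1ThreadCluster;
        const int level1_n_id = level1_id % NLevel1ThreadCluster;

        const int level0_id   = thread_id % ThreadPerLevel0Cluster;
        const int level0_m_id = level0_id / NLevel0ThreadCluster;
        const int level0_n_id = level0_id % NLevel0ThreadCluster;

        const int MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster; // 8
        const int NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster; // 8

        const int row = level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC;
        const int col = level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC;
        std::printf("thread %d -> C sub-tile origin (%d, %d)\n", thread_id, row, col);
    }
}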
@@ -100,8 +100,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     {
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster =
+            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
+        constexpr index_t NPerLevel1Cluster =
+            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

         index_t m_repeat = m_in_c / MPerThreadSubC;
         index_t n_repeat = n_in_c / NPerThreadSubC;
@@ -113,329 +115,233 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
     }

-#if CK_USE_AMD_INLINE_ASM
-    // 1 in specialized template represent pack size. fp32 = 1
-    template <index_t PACKSIZE>
-    __device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
-                                 const typename vector_type<float, 4>::MemoryType& b,
-                                 typename vector_type<float, 4>::MemoryType* c) const
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ void
+    Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
-        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
-        constexpr index_t NRepeat = 2;
+        constexpr auto a_block_mtx  = BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
-        outerProduct1x4(a.x, b, c[0 * NRepeat]);
-        outerProduct1x4(a.y, b, c[1 * NRepeat]);
-        outerProduct1x4(a.z, b, c[2 * NRepeat]);
-        outerProduct1x4(a.w, b, c[3 * NRepeat]);
-    }
+        constexpr index_t K = a_block_mtx.NRow();

-    // 1 in specialized template represent pack size. fp32 = 1
-    template <index_t PACKSIZE>
-    __device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
-                                 const typename vector_type<float, 4>::MemoryType& b,
-                                 typename vector_type<float, 4>::MemoryType* c) const
     {
-        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
-        constexpr index_t NRepeat = 2;
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();
-        outerProduct1x4(a.x, b, c[0 * NRepeat]);
-        outerProduct1x4(a.y, b, c[1 * NRepeat]);
-    }
+        constexpr index_t MPerLevel1Cluster =
+            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
+        constexpr index_t NPerLevel1Cluster =
+            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

-    // 1 in specialized template represent pack size. fp32 = 1
-    template <index_t PACKSIZE>
-    __device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
-                                 const typename vector_type<float, 2>::MemoryType& b,
-                                 typename vector_type<float, 2>::MemoryType* c) const
     {
-        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
-        constexpr index_t NRepeat = 2;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-        outerProduct1x2(a.x, b, c[0 * NRepeat]);
-        outerProduct1x2(a.y, b, c[1 * NRepeat]);
-        outerProduct1x2(a.z, b, c[2 * NRepeat]);
-        outerProduct1x2(a.w, b, c[3 * NRepeat]);
-    }
+        // thread A, B for GEMM
+        constexpr auto a_thread_mtx =
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

-    // 1 in specialized template represent pack size. fp32 = 1
-    template <index_t PACKSIZE>
-    __device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
-                                 const typename vector_type<float, 2>::MemoryType& b,
-                                 typename vector_type<float, 2>::MemoryType* c) const
     {
-        static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
-        constexpr index_t NRepeat = 2;
+        constexpr auto b_thread_mtx =
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-        outerProduct1x2(a.x, b, c[0 * NRepeat]);
-        outerProduct1x2(a.y, b, c[1 * NRepeat]);
-    }
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-    // PACKSIZE for fp16 could be 4 or 2
-    template <index_t PACKSIZE>
-    __device__ void outerProduct(
-        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 4>::MemoryType& a,
-        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 4>::MemoryType& b,
-        typename vector_type<float, 4>::MemoryType* c) const
+        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
+                                                                 decltype(a_thread_mtx),
+                                                                 KPerThreadLoop,
+                                                                 MPerThreadSubC,
+                                                                 DataPerReadA>{};
+        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
+                                                                 decltype(b_thread_mtx),
+                                                                 KPerThreadLoop,
+                                                                 NPerThreadSubC,
+                                                                 DataPerReadB>{};
+        constexpr auto threadwise_gemm =
+            ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
+                                               decltype(b_thread_mtx),
+                                               decltype(c_thread_mtx)>{};

+#pragma unroll
+        // loop over k
+        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
+        {
-        static_assert(2 == PACKSIZE || 4 == PACKSIZE,
-                      "only packsize of 2,4 is supported with float datatype!");
-        constexpr index_t NRepeat = 2;
-        const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
-            reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
-        outerProduct1x4Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
-        outerProduct1x4Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
-        outerProduct1x4Half<PACKSIZE>(reg_a[2], b, c[2 * NRepeat]);
-        outerProduct1x4Half<PACKSIZE>(reg_a[3], b, c[3 * NRepeat]);
-    }

-    // PACKSIZE for fp16 could be 4 or 2
-    template <index_t PACKSIZE>
-    __device__ void outerProduct(
-        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 2>::MemoryType& a,
-        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 2>::MemoryType& b,
-        typename vector_type<float, 2>::MemoryType* c) const
+#pragma unroll
+            // read A
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
     {
-        static_assert(2 == PACKSIZE || 4 == PACKSIZE,
-                      "only packsize of 2,4 is supported with float datatype!");
-        constexpr index_t NRepeat = 2;
-        const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
-            reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
-        outerProduct1x2Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
-        outerProduct1x2Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
+                a_thread_copy.Run(
+                    p_a_block +
+                        a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
+                        mMyThreadOffsetA,
+                    p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
     }

-    // PACKSIZE for fp16 could be 4 or 2
-    template <index_t PACKSIZE>
-    __device__ void outerProduct1x4Half(
-        const typename vector_type<half, PACKSIZE>::MemoryType& a,
-        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 4>::MemoryType& b,
-        vector_type<float, 4>::MemoryType& c) const
+#pragma unroll
+            // read B
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
     {
-        static_if<PACKSIZE == 4>{}([&](auto) {
-            outerProduct1x4dot2TwoTimes(reinterpret_cast<const half2*>(&a),
-                                        reinterpret_cast<const half2*>(&b),
-                                        reinterpret_cast<float*>(&c));
-        }).Else([&](auto) {
-            static_if<PACKSIZE == 2>{}([&](auto) {
-                outerProduct1x4dot2(reinterpret_cast<const half2*>(&a),
-                                    reinterpret_cast<const half2*>(&b),
-                                    reinterpret_cast<float*>(&c));
-            }).Else([&](auto fwd) {
-                // not implemented
-                static_assert(fwd(false), "wrong! packsize = 1 for fp16 is insensible.");
-            });
-        });
+                b_thread_copy.Run(
+                    p_b_block +
+                        b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
+                        mMyThreadOffsetB,
+                    p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
     }

-    // PACKSIZE for fp16 could be 4 or 2
-    template <index_t PACKSIZE>
-    __device__ void outerProduct1x2Half(
-        const typename vector_type<half, PACKSIZE>::MemoryType& a,
-        const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, 2>::MemoryType& b,
-        vector_type<float, 2>::MemoryType& c) const
-    {
-        static_if<PACKSIZE == 4>{}([&](auto) {
-            outerProduct1x2dot2TwoTimes(reinterpret_cast<const half2*>(&a),
-                                        reinterpret_cast<const half2*>(&b),
-                                        reinterpret_cast<float*>(&c));
-        }).Else([&](auto) {
-            static_if<PACKSIZE == 2>{}([&](auto) {
-                outerProduct1x2dot2(reinterpret_cast<const half2*>(&a),
-                                    reinterpret_cast<const half2*>(&b),
-                                    reinterpret_cast<float*>(&c));
-            }).Else([&](auto fwd) {
-                // not implemented
-                static_assert(fwd(false), "wrong! packsize = 1 for fp16 is insensible.");
-            });
-        });
+            // C += A * B
+            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
     }
     }
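The structure of the new Run_naive above (and of the pipelined variant that follows) is the classic register-blocked GEMM: per KPerThreadLoop step, each thread copies its A and B sub-tiles from the block-visible buffers into registers and accumulates an outer product into its private C tile. A small host-side C++ sketch of that loop structure with scalar arrays (all sizes are made up; no LDS, vectorization, or CK descriptors involved):

#include <cstdio>
#include <vector>

// Minimal register-blocked "C += A^T * B" for one thread's tile.
// A is K x M (transposed layout), B is K x N, C is MPerThread x NPerThread.
int main()
{
    const int K = 8, M = 16, N = 16;
    const int MPerThread = 4, NPerThread = 4;
    const int m_begin = 4, n_begin = 8; // this thread's C origin, cf. GetBeginOfThreadMatrixC

    std::vector<float> A(K * M, 1.0f), B(K * N, 2.0f), C(MPerThread * NPerThread, 0.0f);

    for(int k = 0; k < K; ++k)
    {
        float a_reg[4], b_reg[4]; // per-thread copies, cf. p_a_thread / p_b_thread
        for(int m = 0; m < MPerThread; ++m)
            a_reg[m] = A[k * M + (m_begin + m)];
        for(int n = 0; n < NPerThread; ++n)
            b_reg[n] = B[k * N + (n_begin + n)];

        // outer-product accumulation, cf. threadwise_gemm.Run
        for(int m = 0; m < MPerThread; ++m)
            for(int n = 0; n < NPerThread; ++n)
                C[m * NPerThread + n] += a_reg[m] * b_reg[n];
    }

    std::printf("C[0][0] = %f (expect %f)\n", C[0], 1.0f * 2.0f * K);
}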
-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
-                                const FloatB* __restrict__ p_b_block,
-                                FloatC* __restrict__ p_c_thread) const
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ void
+    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
         constexpr auto a_block_mtx  = BlockMatrixA{};
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

         constexpr index_t M = a_block_mtx.NCol();
         constexpr index_t N = b_block_mtx.NCol();
         constexpr index_t K = a_block_mtx.NRow();

         constexpr index_t MPerThread = c_thread_mtx.NRow();
         constexpr index_t NPerThread = c_thread_mtx.NCol();

-        // thread A, B for GEMM
+        constexpr index_t MPerLevel1Cluster =
+            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
+        constexpr index_t NPerLevel1Cluster =
+            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

+        static_assert(MRepeat == 2 && NRepeat == 2,
+                      "wrong! inline asm cannot deal with this GEMM config yet");

+        // thread A, B
+        constexpr auto a_thread_mtx =
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+        constexpr auto b_thread_mtx =
+            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

+        // thread A-sub, B-sub
+        constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
+        constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});

-        static_assert((MPerThreadSubC == 4 || MPerThreadSubC == 2) &&
-                          (NPerThreadSubC == 4 || NPerThreadSubC == 2) && KPerThreadLoop == 1,
-                      "M/NPerThreadSubC wrong!");

+        // thread C-sub
+        constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
+            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});

-        static_assert(MPerThread % 4 == 0 && NPerThread % 4 == 0, "M/NPerThread % 4 != 0");

+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
-        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);

+        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
+                                                                 decltype(a_thread_mtx),
+                                                                 KPerThreadLoop,
+                                                                 MPerThreadSubC,
+                                                                 DataPerReadA>{};

-        static_assert(MRepeat == 2 && NRepeat == 2, "M/NRepeat != 2");

+        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
+                                                                 decltype(b_thread_mtx),
+                                                                 KPerThreadLoop,
+                                                                 NPerThreadSubC,
+                                                                 DataPerReadB>{};

-        using typeA = typename vector_type<FloatA, MPerThreadSubC>::MemoryType;
-        using typeB = typename vector_type<FloatB, NPerThreadSubC>::MemoryType;
-        using typeC = typename vector_type<FloatC, NPerThreadSubC>::MemoryType;

+        constexpr auto threadwise_gemm =
+            ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
+                                               decltype(b_thread_sub_mtx),
+                                               decltype(c_thread_sub_mtx)>{};

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

+        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
+        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
-        typeA* reg_a = reinterpret_cast<typeA*>(p_a_thread);
-        typeB* reg_b = reinterpret_cast<typeB*>(p_b_thread);
-        typeC* reg_c = reinterpret_cast<typeC*>(p_c_thread);

-        reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[mMyThreadOffsetA]);
-        reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[mMyThreadOffsetB]);
-        reg_b[1] = *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + NPerLevel1Cluster)]);
-        reg_a[1] = *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + MPerLevel1Cluster)]);
-        outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]);
-        outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
-#pragma unroll
-        for(index_t k = 1; k < K; ++k)
-        {
-            reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + k * M)]);
-            outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
-            reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + k * N)]);
-            outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
-            reg_b[1] = *reinterpret_cast<const typeB*>(
-                &p_b_block[(mMyThreadOffsetB + k * N + NPerLevel1Cluster)]);
-            reg_a[1] = *reinterpret_cast<const typeA*>(
-                &p_a_block[(mMyThreadOffsetA + k * M + MPerLevel1Cluster)]);
-            outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]);
-            outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
-        }
-        outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
-        outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
-    }
-#endif

+        // read A_sub_0
+        a_thread_copy.Run(p_a_block_off, p_a_thread);
-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void Run_source(const FloatA* const __restrict__ p_a_block,
-                               const FloatB* const __restrict__ p_b_block,
-                               FloatC* const __restrict__ p_c_thread) const
-    {
-        constexpr auto True  = integral_constant<bool, true>{};
-        constexpr auto False = integral_constant<bool, false>{};

+        // read B_sub_0
+        b_thread_copy.Run(p_b_block_off, p_b_thread);

-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};

+        // read B_sub_1
+        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
+                          p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));

-        constexpr index_t K = a_block_mtx.NRow();

+        // read A_sub_1
+        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
+                          p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));

-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();

+        // C_sub_00 += transpose(A_sub_0) * B_sub_0
+        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

-        // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThread>{}, Number<MPerThread>{});

+        // C_sub_01 += transpose(A_sub_0) * B_sub_1
+        threadwise_gemm.Run(p_a_thread,
+                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
+                            p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));

-        constexpr auto b_thread_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThread>{}, Number<NPerThread>{});

+#pragma unroll
+        // loop over rest of k
+        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
+        {
+            // read A_sub_0
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);

-        // thread A-sub, B-sub for copy
-        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

+            // C_sub_10 += transpose(A_sub_1) * B_sub_0
+            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
+                                p_b_thread,
+                                p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));

-        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

+            // read B_sub_0
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

+            // C_sub_11 += transpose(A_sub_1) * B_sub_1
+            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
+                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
+                                p_c_thread +
+                                    ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));

-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

+            // read B_sub_1
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
+                              p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));

-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

+            // read A_sub_1
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
+                              p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));

-#pragma unroll
-        // loop over k
-        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
-        {
-#pragma unroll
-            // copy A-sub to form A
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
-                threadwise_matrix_copy(
-                    a_block_mtx,
-                    p_a_block +
-                        a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
-                        mMyThreadOffsetA,
-                    a_thread_mtx,
-                    p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
-                    a_thread_sub_mtx.GetLengths(),
-                    Number<DataPerReadA>{});
-            }

+            // C_sub_00 += transpose(A_sub_0) * B_sub_0
+            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

-#pragma unroll
-            // copy B-sub to form B
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
-                threadwise_matrix_copy(
-                    b_block_mtx,
-                    p_b_block +
-                        b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
-                        mMyThreadOffsetB,
-                    b_thread_mtx,
-                    p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
-                    b_thread_sub_mtx.GetLengths(),
-                    Number<DataPerReadB>{});

+            // C_sub_01 += transpose(A_sub_0) * B_sub_1
+            threadwise_gemm.Run(p_a_thread,
+                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
+                                p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
            }

-            // C = A * B
-            threadwise_gemm(a_thread_mtx,
-                            True,
-                            p_a_thread,
-                            b_thread_mtx,
-                            False,
-                            p_b_thread,
-                            c_thread_mtx,
-                            False,
-                            p_c_thread);
-        }

+        // C_sub_10 += transpose(A_sub_1) * B_sub_0
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
+                            p_b_thread,
+                            p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));

+        // C_sub_11 += transpose(A_sub_1) * B_sub_1
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
+                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
+                            p_c_thread +
+                                ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
     }
-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
-                        const FloatB* __restrict__ p_b_block,
-                        FloatC* __restrict__ p_c_thread) const
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ void
+    Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
-#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
-        // The assembly path doesn't support bfloat16 using asm instructions
-#if MIOPEN_USE_BFP16 == 1
-        Run_source(p_a_block, p_b_block, p_c_thread);
+#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
+        constexpr index_t MPerThread = ThreadMatrixC::NRow();
+        constexpr index_t NPerThread = ThreadMatrixC::NCol();

+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

+        static_if<MRepeat == 2 && NRepeat == 2>{}([&](auto) {
+            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
+        }).Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); });
 #else
-        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
+        Run_naive(p_a_block, p_b_block, p_c_thread);
 #endif
-#else
-        Run_source(p_a_block, p_b_block, p_c_thread);
-#endif // CK_USE_AMD_INLINE_ASM
     }
 };
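After the rewrite, Run is a thin compile-time dispatcher: with the experimental pipeline flag on and the thread tile being exactly a 2x2 arrangement of sub-tiles it calls Run_pipelined_2x2, otherwise Run_naive. CK's static_if plays the role that `if constexpr` plays in standard C++17; a standalone sketch of the same dispatch pattern (the stand-in functions below are hypothetical, not part of the header):

#include <cstdio>

// Stand-ins for the two kernels; the real ones live in the struct above.
template <int MRepeat, int NRepeat>
void run_gemm()
{
    if constexpr(MRepeat == 2 && NRepeat == 2)
        std::printf("pipelined 2x2 path\n"); // cf. Run_pipelined_2x2
    else
        std::printf("naive path\n"); // cf. Run_naive
}

int main()
{
    run_gemm<2, 2>(); // takes the pipelined path
    run_gemm<4, 2>(); // falls back to the naive path
}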
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
@@ -95,9 +95,21 @@ __device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave,
                                  (mfma_info::group_size * mfma_info::num_blks_wave) +
                              a_off; // A is transposed
                index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk;

-                p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
-                    math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
-                                                                  p_b_wave[bindex]);
+                // p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
+                //     math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
+                //     p_b_wave[bindex]);
+                index_t cindex = m + n * output_m + b * output_m * mfma_info::num_blks_wave;

+                if(blockIdx.x * blockDim.x + threadIdx.x == 0 && cindex == 0)
+                {
+                    printf("Run p_c[%d] = %f, p_a[%d] = %f, p_b[%d] = %f\n",
+                           cindex,
+                           p_c_thread[cindex],
+                           aindex,
+                           p_a_wave[aindex],
+                           bindex,
+                           p_b_wave[bindex]);
+                    p_c_thread[cindex + k] = p_a_wave[aindex];
+                }
             }
         }
     }
...
...
@@ -251,6 +263,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
         constexpr index_t N = BlockMatrixB::NCol();
         constexpr index_t K = BlockMatrixA::NRow();

+        if(blockIdx.x * blockDim.x + threadIdx.x == 0)
+            printf("Run M %d, N %d, K %d\n", M, N, K);

         // static_if<EnableXdlops>{}([&](auto) {
         //     WaveWiseGemmMx64_xdlops<M,
         //                             N,
...
...