yangql / composable_kernel-1

Commit 238d58c2 authored Aug 20, 2019 by Chao Liu

adding tensor_view

parent 08bf57b0
Showing 9 changed files with 609 additions and 19 deletions
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp  +113 −7
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp  +61 −2
composable_kernel/include/tensor_description/tensor_coordinate.hpp  +13 −1
composable_kernel/include/tensor_description/tensor_view.hpp  +100 −0
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp  +131 −6
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  +167 −0
composable_kernel/include/utility/integral_constant.hpp  +3 −0
composable_kernel/include/utility/vector_type.hpp  +19 −1
driver/src/driver.cpp  +2 −2
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
@@ -62,6 +62,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto in_c_h_w_n_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

@@ -121,10 +124,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});

        constexpr auto wei_c_1_1_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, 1, 1, KPerBlock>{}, Number<max_align>{});

        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace();
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace();

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

#if 0
        // blockwise input copy
        // format is [C, Hi, Wi, N]
        auto blockwise_in_copy =

@@ -142,7 +156,31 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
                                   InBlockCopyDataPerAccess_N,
                                   InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
                                                               {0, 0, 0, 0});
#else
        auto in_c_h_w_n_global = make_TensorView(in_c_h_w_n_global_desc, p_in_global);
        auto in_c_h_w_n_block  = make_TensorView(in_c_h_w_n_block_desc, p_in_block);

        auto blockwise_in_copy =
            BlockwiseGenericTensorSliceCopy_v3<BlockSize,
                                               decltype(in_c_h_w_n_global),
                                               decltype(in_c_h_w_n_block),
                                               decltype(in_c_h_w_n_block.GetLengths()),
                                               InBlockCopySubLengths_CHWN,
                                               InBlockCopyClusterLengths_CHWN,
                                               Sequence<0, 1, 2, 3>,
                                               Sequence<0, 1, 2, 3>,
                                               Sequence<0, 1, 2, 3>,
                                               3,
                                               3,
                                               InBlockCopyDataPerAccess_N,
                                               InBlockCopyDataPerAccess_N>(
                in_c_h_w_n_global,
                {0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin},
                in_c_h_w_n_block,
                {0, 0, 0, 0});
#endif

#if 0
        // blockwise wei copy
        // format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =

@@ -159,6 +197,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
                                   1,
                                   WeiBlockCopyDataPerAccess_K,
                                   WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
#else
        auto wei_c_y_x_k_global = make_TensorView(wei_c_y_x_k_global_desc, p_wei_global);
        auto wei_c_1_1_k_block  = make_TensorView(wei_c_1_1_k_block_desc, p_wei_block);

        constexpr index_t WeiBlockCopySubLengths_C = WeiBlockCopySubLengths_CK{}[0];
        constexpr index_t WeiBlockCopySubLengths_K = WeiBlockCopySubLengths_CK{}[1];

        using WeiBlockCopySubLengths_CYXK =
            Sequence<WeiBlockCopySubLengths_C, 1, 1, WeiBlockCopySubLengths_K>;

        constexpr index_t WeiBlockCopyClusterLengths_C = WeiBlockCopyClusterLengths_CK{}[0];
        constexpr index_t WeiBlockCopyClusterLengths_K = WeiBlockCopyClusterLengths_CK{}[1];

        using WeiBlockCopyClusterLengths_CYXK =
            Sequence<WeiBlockCopyClusterLengths_C, 1, 1, WeiBlockCopyClusterLengths_K>;

        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v3<BlockSize,
                                               decltype(wei_c_y_x_k_global),
                                               decltype(wei_c_1_1_k_block),
                                               decltype(wei_c_1_1_k_block.GetLengths()),
                                               WeiBlockCopySubLengths_CYXK,
                                               WeiBlockCopyClusterLengths_CYXK,
                                               Sequence<0, 1, 2, 3>,
                                               Sequence<0, 1, 2, 3>,
                                               Sequence<0, 1, 2, 3>,
                                               3,
                                               3,
                                               WeiBlockCopyDataPerAccess_K,
                                               WeiBlockCopyDataPerAccess_K>(
                wei_c_y_x_k_global,
                {0, 0, 0, k_block_data_begin},
                wei_c_1_1_k_block,
                {0, 0, 0, 0});
#endif

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix

@@ -200,13 +270,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
                                   GemmDataPerReadA,
                                   GemmDataPerReadB>{};

-       // LDS: be careful of alignment
-       constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace();
-       constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace();
-
-       __shared__ Float p_in_block[in_block_space];
-       __shared__ Float p_wei_block[wei_block_space];
-
        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];

@@ -215,6 +278,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
        // set threadwise output tensor to 0
        threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);

#if 0
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)

@@ -246,6 +310,48 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                    c_block_data_begin += CPerBlock)
                {
#if 1 // debug
                    blockwise_in_copy.Run();
                    blockwise_wei_copy.Run();
#endif

                    __syncthreads();

                    blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();

                    // move along C
                    blockwise_in_copy.MoveSrcSliceWindow(Sequence<CPerBlock, 0, 0, 0>{}, True);
                    blockwise_wei_copy.MoveSrcSliceWindow(Sequence<CPerBlock, 0, 0, 0>{}, True);
                }

                // reset C
                blockwise_in_copy.MoveSrcSliceWindow(Sequence<C, 0, 0, 0>{}, False);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<C, 0, 0, 0>{}, False);

                // move along X
                blockwise_in_copy.MoveSrcSliceWindow(Sequence<0, 0, 1, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<0, 0, 1, 0>{}, True);
            }

            // reset X
            blockwise_in_copy.MoveSrcSliceWindow(Sequence<0, 0, X, 0>{}, False);
            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<0, 0, X, 0>{}, False);

            // move along Y
            blockwise_in_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, False);
            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, False);
        }
#endif

        // output: register to global mem
        const auto c_thread_mtx_begin =
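The window bookkeeping above is easier to see in flat-offset terms: for a NormalTensorView, MoveSrcSliceWindow turns the step Sequence into a coordinate and adds (positive direction) or subtracts (negative direction) its offset to the window's data pointer, which is exactly coordinate_type{step_sizes}.GetOffset() in tensor_view.hpp later in this commit. A standalone model of the inner C loop, with hypothetical tensor lengths, not repository code:

// Standalone sketch: models the "move along C" / "reset C" pattern for one slicing window.
#include <cassert>

int main()
{
    const long C = 16, H = 8, W = 8, N = 32;
    const long stride[4] = {H * W * N, W * N, N, 1}; // packed CHWN strides

    const long CPerBlock = 4;
    long window_offset = 0; // flat offset of the slicing window into the global tensor

    for(long c = 0; c < C; c += CPerBlock)
    {
        // ... blockwise copies and the batched GEMM run on the current window here ...

        // MoveSrcSliceWindow(Sequence<CPerBlock, 0, 0, 0>{}, True)
        window_offset += CPerBlock * stride[0];
    }

    // "reset C": MoveSrcSliceWindow(Sequence<C, 0, 0, 0>{}, False)
    window_offset -= C * stride[0];

    assert(window_offset == 0); // the window is back at its origin before stepping along X
    return 0;
}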
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
@@ -204,7 +204,7 @@ struct ConstantTensorDescriptor
    }

    // This function doesn't do carry check on the highest dimension for positive stepping (or
-   // borrow check on the lowest dimension for negative stepping), for performance reason. It is
+   // borrow check on the highest dimension for negative stepping), for performance reason. It is
    // the user's responsibility to make sure the result "new_mutli_id" is not out-of-bound on the
    // highest dimension for positive stepping (or on the lowest dimension for negative stepping)
    template <bool PositiveDirection>

@@ -304,14 +304,73 @@ struct ConstantTensorDescriptor
                                     GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
    }

    template <index_t IDimVector, index_t DataPerVector>
    struct lambda_IsVectorizationAllowed
    {
        bool& is_allowed;

        __host__ __device__ constexpr lambda_IsVectorizationAllowed(bool& is_allowed_)
            : is_allowed(is_allowed_)
        {
        }

        template <index_t IDim_>
        __host__ __device__ constexpr void operator()(Number<IDim_>) const
        {
            constexpr auto IDim = Number<IDim_>{};

            if(IDimVector != IDim && Strides::Get(IDim) % DataPerVector != 0)
            {
                is_allowed = false;
            }
        }
    };

    template <index_t IDimVector, index_t DataPerVector>
    __host__ __device__ static constexpr bool IsVectorizationAllowed(Number<IDimVector>,
                                                                     Number<DataPerVector>)
    {
        bool is_allowed = (Strides{}[IDimVector] == 1 || DataPerVector == 1) &&
                          Lengths{}[IDimVector] % DataPerVector == 0;

        static_for<0, nDim, 1>{}(
            lambda_IsVectorizationAllowed<IDimVector, DataPerVector>{is_allowed});

        return is_allowed;
    }

    template <index_t IDim, index_t DataPerVector>
    __host__ __device__ static constexpr auto Vectorize(Number<IDim>, Number<DataPerVector>)
    {
        constexpr auto idim            = Number<IDim>{};
        constexpr auto data_per_vector = Number<DataPerVector>{};

        static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");

        using vectorized_lengths =
            decltype(Lengths::Modify(Number<IDim>{}, Number<Lengths{}[IDim] / DataPerVector>{}));

        using vectorized_strides =
            decltype((Strides{} / Number<DataPerVector>{}).Modify(Number<IDim>{}, Number<1>{}));

        return ConstantTensorDescriptor<vectorized_lengths, vectorized_strides>{};
    }

    template <index_t IDim, index_t SliceLen>
    __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
    {
-       using slice_lengths = decltype(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}));
+       using slice_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLen>{}));

        return ConstantTensorDescriptor<slice_lengths, Strides>{};
    }

    template <index_t... Is>
    __host__ __device__ static constexpr auto Slice(Sequence<Is...> slice_lengths)
    {
        static_assert(slice_lengths.GetSize() == nDim, "wrong!");

        return ConstantTensorDescriptor<decltype(slice_lengths), Strides>{};
    }

    template <index_t IDim, index_t SliceLength, index_t SliceStride>
    __host__ __device__ static constexpr auto
    StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
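What Vectorize does to a descriptor, restated as plain arithmetic: the vectorized dimension's length is divided by DataPerVector, every stride is divided by DataPerVector, and the vectorized dimension's stride is forced to 1; IsVectorizationAllowed is the precondition that makes those divisions exact. A standalone sketch with hypothetical lengths, not repository code:

// Standalone model of Vectorize(Number<IDim>, Number<DataPerVector>) on a packed descriptor.
#include <cassert>

int main()
{
    // packed [8, 4, 16] descriptor (innermost dimension 2 has stride 1)
    long lengths[3] = {8, 4, 16};
    long strides[3] = {4 * 16, 16, 1};

    const int idim              = 2; // vectorize along the innermost dimension
    const long data_per_vector  = 4; // allowed: strides[idim] == 1, lengths[idim] % 4 == 0,
                                     // and every other stride is a multiple of 4

    for(int d = 0; d < 3; ++d)
        strides[d] /= data_per_vector; // all strides are now counted in vectors
    strides[idim] = 1;                 // the vectorized dimension becomes the contiguous one
    lengths[idim] /= data_per_vector;  // and holds 1/DataPerVector as many elements

    assert(lengths[2] == 4 && strides[0] == 16 && strides[1] == 4 && strides[2] == 1);
    return 0;
}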
composable_kernel/include/tensor_description/tensor_coordinate.hpp
@@ -7,6 +7,7 @@
namespace ck
{

// TensorDesc is ConstantTensorDescriptor
template <class TensorDesc>
struct NormalTensorCoordinate
{

@@ -26,6 +27,12 @@ struct NormalTensorCoordinate
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr NormalTensorCoordinate(Sequence<Xs...>)
        : NormalTensorCoordinate(Array<index_t, nDim>{Xs...})
    {
    }

    __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

    // T is Array or Sequence

@@ -87,6 +94,7 @@ struct NormalTensorCoordinate
    index_t mOffset;
};

// TensorDesc is ConstantMergedTensorDescriptor
template <class TensorDesc>
struct MergedTensorCoordinate
{

@@ -235,6 +243,8 @@ struct MergedTensorCoordinate
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        static_for<0, nDim, 1>{}([&](auto idim) {
            // compiler should remove dead code path, because step_sizes is known at
            // compile time
            if(step_sizes[idim] != 0)
            {
                this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});

@@ -250,6 +260,8 @@ struct MergedTensorCoordinate
        static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");

        static_for<0, nDim, 1>{}([&](auto idim) {
            // compiler should remove dead code path, because step_sizes is known at
            // compile time
            if(step_sizes[idim] != 0)
            {
                this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});

@@ -287,7 +299,7 @@ struct MergedTensorCoordinate
    // arithmetic after construction of TensorCoordinate.
    // TODO: refactor TensorCoordinate, after introducing the concept of "dimensions"
    // and simplify implementation of ConstantMergedTensorDescriptor, so we don't need to
-   // count on compiler to optimize way those register memory for us
+   // count on compiler to optimize away those register memory for us

    Array<index_t, nOriginalDim> mOriginalIndex;
    Array<index_t, nDim> mPartialOffsets;
composable_kernel/include/tensor_description/tensor_view.hpp
new file mode 100644
#ifndef CK_TENSOR_VIEW_HPP
#define CK_TENSOR_VIEW_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_coordinate.hpp"

namespace ck
{

// TensorDesc is ConstantTensorDescriptor or ConstantMergedTensorDescriptor
template <class TensorDesc, class TData>
struct NormalTensorView
{
    using type             = NormalTensorView;
    using tensor_desc_type = TensorDesc;
    using coordinate_type  = typename NormalTensorCoordinate<TensorDesc>::type;
    using data_type        = TData;

    static constexpr auto nDim = TensorDesc::GetNumOfDimension();

    __host__ __device__ constexpr NormalTensorView(TData* p_data) : mpData{p_data} {}

    __host__ __device__ constexpr NormalTensorView() : NormalTensorView{nullptr} {}

    __host__ __device__ static constexpr auto GetNumOfDimension() { return nDim; }

    __host__ __device__ static constexpr auto GetLengths() { return TensorDesc::GetLengths(); }

    __host__ __device__ const TData& operator[](coordinate_type coord) const
    {
        return mpData[coord.GetOffset()];
    }

    __host__ __device__ TData& operator()(coordinate_type coord) const
    {
        return mpData[coord.GetOffset()];
    }

    template <class IDim, class DataPerVector>
    __host__ __device__ static constexpr auto IsVectorizationAllowed(IDim, DataPerVector)
    {
        return TensorDesc::IsVectorizationAllowed(IDim{}, DataPerVector{});
    }

    template <class IDim, class DataPerVector>
    __host__ __device__ auto Vectorize(IDim idim, DataPerVector data_per_vector) const
    {
        static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");

        using vector_t = typename vector_type<TData, data_per_vector>::MemoryType;

        return NormalTensorView<decltype(TensorDesc::Vectorize(idim, data_per_vector)), vector_t>(
            reinterpret_cast<vector_t*>(mpData));
    }

    template <index_t... Is>
    __host__ __device__ auto Slice(coordinate_type slice_origin, Sequence<Is...> slice_lengths)
    {
        static_assert(slice_lengths.GetSize() == nDim, "wrong!");

        return NormalTensorView<decltype(TensorDesc::Slice(slice_lengths)), TData>(
            mpData + slice_origin.GetOffset());
    }

    template <class IDim, class SliceLen>
    __host__ __device__ auto Slice(coordinate_type slice_origin, IDim idim, SliceLen slice_len) const
    {
        return NormalTensorView<decltype(TensorDesc::Slice(idim, slice_len)), TData>(
            mpData + slice_origin.GetOffset());
    }

    // slice_window is a slicing window on "*this"
    template <class SliceWindow, class T, bool PositiveDirection>
    __device__ void MoveSliceWindow(SliceWindow& slice_window,
                                    T step_sizes,
                                    integral_constant<bool, PositiveDirection>)
    {
        if(PositiveDirection)
        {
            slice_window.mpData += coordinate_type{step_sizes}.GetOffset();
        }
        else
        {
            slice_window.mpData -= coordinate_type{step_sizes}.GetOffset();
        }
    }

    // private:
    data_type* mpData;
};

template <class... Xs, class TData>
__host__ __device__ constexpr auto make_TensorView(ConstantTensorDescriptor<Xs...>, TData* p_data)
{
    return NormalTensorView<ConstantTensorDescriptor<Xs...>, TData>{p_data};
}

} // namespace ck
#endif
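A minimal usage sketch of the new view API (not part of the commit; the descriptor lengths and the function name are made up, and it only compiles against the headers above): a view is a descriptor type plus a raw pointer, element access goes through a coordinate, and a slice is another view over the same memory with the data pointer shifted to the slice origin.

// Hypothetical illustration only; assumes this repository's headers.
#include "tensor_view.hpp"

using namespace ck;

__device__ void tensor_view_usage_sketch(float* p_in)
{
    // packed [C=8, H=4, W=4, N=16] descriptor
    constexpr auto desc = make_ConstantTensorDescriptor_packed(Sequence<8, 4, 4, 16>{});

    auto view = make_TensorView(desc, p_in);

    using coord_t = decltype(view)::coordinate_type;

    // element access: the coordinate folds the multi-index into a flat offset
    view(coord_t{Sequence<0, 1, 2, 3>{}}) = 1.0f;

    // slicing window: same strides, smaller lengths, pointer moved to the window origin
    auto window = view.Slice(coord_t{Sequence<0, 0, 0, 0>{}}, Sequence<8, 1, 1, 16>{});
    (void)window;
}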
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
@@ -5,6 +5,7 @@
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_coordinate.hpp"
#include "tensor_view.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1

@@ -442,12 +443,13 @@ struct BlockwiseGenericTensorSliceCopy_v2
    __device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
                                                            DstCoordinate dst_block_slice_origin)
    {
-       static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
-                         nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
-                         nDim == ThreadClusterLengths::GetSize() &&
-                         nDim == ThreadClusterArrangeOrder::GetSize(),
-                     "wrong! nDim not consistent");
+       static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
+                         nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
+                         nDim == ThreadClusterLengths::GetSize() &&
+                         nDim == ThreadClusterArrangeOrder::GetSize() &&
+                         nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
+                     "wrong! nDim not consistent");

        static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
                      "wrong! threads should be mapped to cover entire slicing window");

@@ -542,6 +544,129 @@ struct BlockwiseGenericTensorSliceCopy_v2
    ThreadwiseStore mThreadwiseStore;
};

template <index_t BlockSize,
          class SrcTensor,
          class DstTensor,
          class SliceLengths,
          class SubLengths,
          class ThreadClusterLengths,
          class ThreadClusterArrangeOrder,
          class SrcDimAccessOrder,
          class DstDimAccessOrder,
          index_t SrcVectorAccessDim,
          index_t DstVectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v3
{
    static constexpr index_t nDim = SrcTensor::GetNumOfDimension();

    using data_type     = remove_cv_t<typename SrcTensor::data_type>;
    using SrcCoordinate = typename SrcTensor::coordinate_type;
    using DstCoordinate = typename DstTensor::coordinate_type;

    __device__ constexpr BlockwiseGenericTensorSliceCopy_v3(SrcTensor src_block,
                                                            SrcCoordinate src_block_slice_origin,
                                                            DstTensor dst_block,
                                                            DstCoordinate dst_block_slice_origin)
        : mThreadBuffer{make_TensorView(ThreadBufferDesc{}, mpBuffer)}
    {
        static_assert(nDim == SrcTensor::GetNumOfDimension() &&
                          nDim == DstTensor::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
                          nDim == SubLengths::GetSize() && nDim == ThreadClusterLengths::GetSize() &&
                          nDim == ThreadClusterArrangeOrder::GetSize() &&
                          nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
                      "wrong! nDim not consistent");

        static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
                      "wrong! threads should be mapped to cover entire slicing window");

        static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
                              remove_cv_t<typename DstTensor::data_type>>{},
                      "wrong! type conversion not supported yet");

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
            ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));

        static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
                      "wrong! BlockSize not consistent with ThreadClusterLengths");

        const auto thread_cluster_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        const auto data_cluster_id =
            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

        mThreadwiseLoad = ThreadwiseLoad(src_block,
                                         src_block_slice_origin + thread_data_id_begin,
                                         mThreadBuffer,
                                         make_zero_array<index_t, nDim>());

        mThreadwiseStore = ThreadwiseStore(mThreadBuffer,
                                           make_zero_array<index_t, nDim>(),
                                           dst_block,
                                           dst_block_slice_origin + thread_data_id_begin);
    }

    __device__ void RunLoadRegisterBuffer() { mThreadwiseLoad.Run(); }

    __device__ void RunStoreRegisterBuffer() const { mThreadwiseStore.Run(); }

    __device__ void Run()
    {
        mThreadwiseLoad.Run();
        mThreadwiseStore.Run();
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveSrcSliceWindow(T step_sizes,
                                       integral_constant<bool, PositiveDirection> positive_direction)
    {
        mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveDstSliceWindow(T step_sizes,
                                       integral_constant<bool, PositiveDirection> positive_direction)
    {
        mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
    }

    private:
    using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));

    using ThreadBufferTensor = NormalTensorView<ThreadBufferDesc, data_type>;

    using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3<SrcTensor,
                                                               ThreadBufferTensor,
                                                               SubLengths,
                                                               SrcDimAccessOrder,
                                                               SrcDimAccessOrder,
                                                               SrcVectorAccessDim,
                                                               SrcVectorAccessDim,
                                                               SrcDataPerAccess,
                                                               1>;

    using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3<ThreadBufferTensor,
                                                                DstTensor,
                                                                SubLengths,
                                                                DstDimAccessOrder,
                                                                DstDimAccessOrder,
                                                                DstVectorAccessDim,
                                                                DstVectorAccessDim,
                                                                1,
                                                                DstDataPerAccess>;

    data_type mpBuffer[ThreadBufferDesc::GetElementSpace()];

    ThreadBufferTensor mThreadBuffer;

    ThreadwiseLoad mThreadwiseLoad;
    ThreadwiseStore mThreadwiseStore;
};

} // namespace ck
#endif
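How the gridwise kernel earlier in this commit wires the new class up, condensed into one annotated sketch. All names here (src_view, dst_view, SubLengths4d, ClusterLengths4d, DataPerAccess, Step, the origins) are placeholders, not repository identifiers, and the per-parameter comments are a reading of the constructor's static_asserts: element-wise, SliceLengths == SubLengths * ThreadClusterLengths, and BlockSize must equal the number of threads described by ThreadClusterLengths.

// Sketch only, not repository code.
auto copy = BlockwiseGenericTensorSliceCopy_v3<
    BlockSize,                        // cooperating threads
    decltype(src_view),               // SrcTensor: e.g. a view of global memory
    decltype(dst_view),               // DstTensor: e.g. a view of LDS
    decltype(dst_view.GetLengths()),  // SliceLengths: the whole block-level slice
    SubLengths4d,                     // per-thread piece of that slice
    ClusterLengths4d,                 // how threads tile the slice
    Sequence<0, 1, 2, 3>,             // ThreadClusterArrangeOrder
    Sequence<0, 1, 2, 3>,             // SrcDimAccessOrder
    Sequence<0, 1, 2, 3>,             // DstDimAccessOrder
    3, 3,                             // Src/Dst vector access dimension
    DataPerAccess, DataPerAccess>(    // Src/Dst data per vectorized access
    src_view, src_origin, dst_view, dst_origin);

copy.Run();                                        // load to registers, then store
copy.MoveSrcSliceWindow(Sequence<Step, 0, 0, 0>{}, // slide the source window
                        integral_constant<bool, true>{});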
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
@@ -5,6 +5,7 @@
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_coordinate.hpp"
#include "tensor_view.hpp"

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0

@@ -773,5 +774,171 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
    DstCoordinate mDstSliceOrigin;
};

template <class SrcTensor,
          class DstTensor,
          class SliceLengths,
          class SrcDimAccessOrder,
          class DstDimAccessOrder,
          index_t SrcVectorAccessDim,
          index_t DstVectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v3
{
    static constexpr index_t nDim = SrcTensor::GetNumOfDimension();

    using data_type     = remove_cv_t<typename SrcTensor::data_type>;
    using SrcCoordinate = typename SrcTensor::coordinate_type;
    using DstCoordinate = typename DstTensor::coordinate_type;

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v3(SrcTensor src,
                                                             SrcCoordinate src_slice_origin,
                                                             DstTensor dst,
                                                             DstCoordinate dst_slice_origin)
        : mSrc{src},
          mDst{dst},
          mSrcSlice{src.Slice(src_slice_origin, SliceLengths{})},
          mDstSlice{dst.Slice(dst_slice_origin, SliceLengths{})}
    {
        static_assert(nDim == SrcTensor::GetNumOfDimension() &&
                          nDim == DstTensor::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
                          nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
                          is_valid_sequence_map<DstDimAccessOrder>::value,
                      "wrong! map is not valid");

        static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
                              remove_cv_t<typename DstTensor::data_type>>{},
                      "wrong! type conversion is not supported yet");

        static_assert(decltype(mSrcSlice)::IsVectorizationAllowed(Number<SrcVectorAccessDim>{},
                                                                  Number<SrcDataPerAccess>{}) &&
                          decltype(mDstSlice)::IsVectorizationAllowed(Number<DstVectorAccessDim>{},
                                                                      Number<DstDataPerAccess>{}),
                      "wrong! vectorized access is not allowed");
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v3()
        : ThreadwiseGenericTensorSliceCopy_v3(
              SrcTensor{}, SrcCoordinate{}, DstTensor{}, DstCoordinate{})
    {
    }

    __device__ void Run() const
    {
        // buffer
        constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SrcTensor::GetLengths());

        data_type p_buffer[buffer_desc.GetElementSpace()];

        auto buffer = make_TensorView(buffer_desc, p_buffer);

        // copy data from src into buffer
        {
            using src_vector_t = typename vector_type<data_type, SrcDataPerAccess>::MemoryType;

            constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
            constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};

            auto src_slice_vectorized =
                mSrcSlice.Vectorize(src_vector_access_dim, src_data_per_access);

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor("mSrcSlice: ", typename decltype(mSrcSlice)::tensor_desc_type{});
                print_ConstantTensorDescriptor("src_slice_vector: ", typename decltype(src_slice_vectorized)::tensor_desc_type{});
            }
#endif

#if 1 // debug
            ford<decltype(src_slice_vectorized.GetLengths()), SrcDimAccessOrder>{}(
                [&](auto src_vector_id) {
                    // load vector from src
                    const src_vector_t vector_data = src_slice_vectorized[src_vector_id];

                    // unpack vector into buffer
                    auto src_scalar_id = src_vector_id;
                    src_scalar_id(src_vector_access_dim) *= src_data_per_access;

                    for(index_t i = 0; i < SrcDataPerAccess; ++i)
                    {
                        auto id = make_zero_array<index_t, nDim>();
                        id(src_vector_access_dim) = i;

                        buffer(src_scalar_id + id) =
                            reinterpret_cast<const data_type*>(&vector_data)[i];
                    }
                });
#endif
        }

        // copy data from buffer into dst
        {
            using dst_vector_t = typename vector_type<data_type, DstDataPerAccess>::MemoryType;

            constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
            constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};

            auto dst_slice_vectorized =
                mDstSlice.Vectorize(dst_vector_access_dim, dst_data_per_access);

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor("mDstSlice: ", typename decltype(mDstSlice)::tensor_desc_type{});
                print_ConstantTensorDescriptor("dst_slice_vector: ", typename decltype(dst_slice_vectorized)::tensor_desc_type{});
            }
#endif

#if 1 // debug
            ford<decltype(dst_slice_vectorized.GetLengths()), DstDimAccessOrder>{}(
                [&](auto dst_vector_id) {
                    dst_vector_t vector_data{};

                    // pack vector from buffer
                    auto dst_scalar_id = dst_vector_id;
                    dst_scalar_id(dst_vector_access_dim) *= dst_data_per_access;

                    for(index_t i = 0; i < DstDataPerAccess; ++i)
                    {
                        auto id = make_zero_array<index_t, nDim>();
                        id(dst_vector_access_dim) = i;

                        reinterpret_cast<data_type*>(&vector_data)[i] = buffer[dst_scalar_id + id];
                    }

                    // write vector into dst
                    dst_slice_vectorized(dst_vector_id) = vector_data;
                });
#endif
        }
    }

    // T can be Sequence or Array
    template <class T, bool PositiveDirection>
    __device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
    {
        mSrc.MoveSliceWindow(mSrcSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
    }

    template <class T, bool PositiveDirection>
    __device__ void MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
    {
        mDst.MoveSliceWindow(mDstSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
    }

    private:
    using SrcSlice = decltype(SrcTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));
    using DstSlice = decltype(DstTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));

    SrcTensor mSrc;
    DstTensor mDst;

    SrcSlice mSrcSlice;
    DstSlice mDstSlice;
};

} // namespace ck
#endif
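The Run() above moves data in two vectorized stages: wide loads from the source slice are unpacked scalar-by-scalar into a register buffer, then scalars are packed back into wide values and stored to the destination slice. A standalone model of that pack/unpack step (float4_demo is a local stand-in for the repository's vector type, not real repository code):

// Standalone sketch of the pack/unpack pattern used by the vectorized copy.
#include <cstdio>

struct float4_demo
{
    float d[4]; // one "wide" value holding 4 scalars
};

int main()
{
    float buffer[4] = {0.f, 1.f, 2.f, 3.f};

    // pack: gather 4 scalars from the buffer into one vector-sized value
    float4_demo v{};
    for(int i = 0; i < 4; ++i)
        reinterpret_cast<float*>(&v)[i] = buffer[i];

    // unpack: scatter the vector back out as scalars
    float out[4];
    for(int i = 0; i < 4; ++i)
        out[i] = reinterpret_cast<const float*>(&v)[i];

    std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
    return 0;
}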
composable_kernel/include/utility/integral_constant.hpp
@@ -23,6 +23,9 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};

template <class T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <index_t N>
using Number = integral_constant<index_t, N>;
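remove_cv_t mirrors C++14's std::remove_cv_t; the new copy classes use it so that a const-qualified source element type (for example a read-only view of const float) still compares equal to the destination's plain float. A standalone check, not repository code:

// Standalone illustration.
#include <type_traits>

static_assert(std::is_same<std::remove_cv<const float>::type, float>::value,
              "stripping const lets a const-float source match a float destination");

int main() { return 0; }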
composable_kernel/include/utility/vector_type.hpp
@@ -14,7 +14,7 @@ struct vector_type
template <>
struct vector_type<float, 1>
{
-   typedef float MemoryType;
+   using MemoryType = float;

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)

@@ -64,6 +64,24 @@ struct vector_type<float, 4>
    }
};

template <>
struct vector_type<const float, 1>
{
    using MemoryType = const float;
};

template <>
struct vector_type<const float, 2>
{
    using MemoryType = const float2_t;
};

template <>
struct vector_type<const float, 4>
{
    using MemoryType = const float4_t;
};

} // namespace ck
#endif
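The const specializations carry only MemoryType, which is all the read path needs: NormalTensorView::Vectorize instantiates vector_type<TData, N> with TData possibly const-qualified (a read-only global view), and MemoryType must then stay const so the reinterpret_cast in Vectorize remains const-correct. A standalone model of the same specialization pattern, with demo names rather than repository types:

// Standalone illustration; float2_demo stands in for float2_t.
#include <type_traits>

struct float2_demo { float x, y; };

template <class T, int N>
struct vector_type_demo;

template <>
struct vector_type_demo<float, 2>
{
    using MemoryType = float2_demo;
};

template <>
struct vector_type_demo<const float, 2>
{
    using MemoryType = const float2_demo; // const element type maps to a const vector type
};

static_assert(std::is_same<vector_type_demo<const float, 2>::MemoryType, const float2_demo>::value,
              "a view of const float vectorizes to a const vector type");

int main() { return 0; }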
driver/src/driver.cpp
@@ -72,9 +72,9 @@ int main(int argc, char* argv[])
{
    using namespace ck;

-#if 0
+#if 1
    constexpr index_t N = 64;
-   constexpr index_t C = 1536;
+   constexpr index_t C = 8;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 256;