gaoqiong / composable_kernel / Commits / 5c8112d9

Commit 5c8112d9, authored Jan 28, 2022 by Jing Zhang

    merge develop

Parents: 7b1ce567, ca47a6cf
Changes: 89

Showing 20 changed files with 4446 additions and 389 deletions (+4446 -389)
composable_kernel/include/tensor_description/static_tensor.hpp                                +20   -15
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp           +26   -12
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v5r1.hpp           +6    -6
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r1.hpp           +133  -0
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r2.hpp           +157  -0
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp           +182  -0
composable_kernel/include/tensor_operation/element_wise_operation.hpp                         +185  -0
composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp                +2    -2
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp                       +3    -3
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp                      +114  -149
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp                      +76   -76
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp                      +70   -90
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp                      +617  -0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp                      +744  -0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp                      +784  -0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp                      +823  -0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp               +8    -8
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp          +12   -13
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp          +453  -0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp          +31   -15
composable_kernel/include/tensor_description/static_tensor.hpp

 #ifndef CK_STATIC_TENSOR_HPP
 #define CK_STATIC_TENSOR_HPP

 #include "ignore.hpp"

 namespace ck {

 // StaticTensor for Scalar

@@ -17,10 +15,10 @@ struct StaticTensor
     static constexpr index_t ndim_               = TensorDesc::GetNumOfDimension();
     static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

-    __host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
+    __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}

     __host__ __device__ constexpr StaticTensor(T invalid_element_value)
-        : invalid_element_value_{invalid_element_value}
+        : invalid_element_scalar_value_{invalid_element_value}
     {
     }

@@ -44,11 +42,11 @@ struct StaticTensor
     {
         if constexpr(InvalidElementUseNumericalZeroValue)
         {
-            return T{0};
+            return zero_scalar_value_;
         }
         else
         {
-            return invalid_element_value_;
+            return invalid_element_scalar_value_;
         }
     }

@@ -71,12 +69,14 @@ struct StaticTensor
         }
         else
         {
-            return ignore;
+            return ignored_element_scalar_;
         }
     }

     StaticBuffer<AddressSpace, T, element_space_size_, true> data_;

-    T invalid_element_value_ = T{0};
+    static constexpr T zero_scalar_value_ = T{0};
+    const T invalid_element_scalar_value_;
+    T ignored_element_scalar_;
 };

 // StaticTensor for vector

@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
     using V = vector_type<S, ScalarPerVector>;

-    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
+    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_scalar_value_{0}
+    {
+    }

     __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
-        : invalid_element_value_{invalid_element_value}
+        : invalid_element_scalar_value_{invalid_element_value}
     {
     }

@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
     {
         if constexpr(InvalidElementUseNumericalZeroValue)
         {
-            return S{0};
+            return zero_scalar_value_;
         }
         else
         {
-            return invalid_element_value_;
+            return invalid_element_scalar_value_;
         }
     }

@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
         }
         else
         {
-            return ignore;
+            return ignored_element_scalar_;
         }
     }

@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
         else
         {
             // TODO: is this right way to initialize a vector?
-            return X{invalid_element_value_};
+            return X{invalid_element_scalar_value_};
         }
     }

@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
     }

     StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;

-    S invalid_element_value_ = S{0};
+    static constexpr S zero_scalar_value_ = S{0};
+    const S invalid_element_scalar_value_ = S{0};
+    S ignored_element_scalar_;
 };

 template <AddressSpaceEnum_t AddressSpace,
           ...
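Note (editor's illustration, not part of the commit): the diff above replaces the returned temporaries (T{0}, ignore) and the mutable default member with named members: a static constexpr zero, a const invalid value, and a dedicated "ignored" slot. A minimal host-side sketch of why that pattern matters, assuming the accessor hands out a reference; the type and values below are hypothetical, not CK's API.

    #include <cassert>

    template <typename T>
    struct ScalarBox
    {
        static constexpr T zero_scalar_value_ = T{0}; // shared constant, lives for the whole program
        const T invalid_element_scalar_value_;        // fixed once at construction

        explicit constexpr ScalarBox(T invalid) : invalid_element_scalar_value_{invalid} {}

        // Returning a reference is only safe because both candidates are members,
        // not temporaries like T{0} constructed inside the function.
        constexpr const T& InvalidValue(bool use_zero) const
        {
            return use_zero ? zero_scalar_value_ : invalid_element_scalar_value_;
        }
    };

    int main()
    {
        constexpr ScalarBox<int> box{-1};
        assert(box.InvalidValue(true) == 0);
        assert(box.InvalidValue(false) == -1);
    }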
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp → composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp (renamed)

-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer_v3r2.hpp"
+#include "threadwise_tensor_slice_transfer_v3r1.hpp"

 namespace ck {

@@ -15,9 +15,9 @@ namespace ck {
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
+          typename SrcElementwiseOperation,
+          typename DstElementwiseOperation,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadSliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename SrcData,
           ...

@@ -34,35 +34,38 @@ template <index_t BlockSize,
           index_t DstScalarStrideInVector,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v4
+struct BlockwiseTensorSliceTransfer_v4r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

+    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v4(
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
         const SrcDesc& src_desc,
         const Index& src_block_slice_origin,
+        const SrcElementwiseOperation& src_element_op,
         const DstDesc& dst_desc,
         const Index& dst_block_slice_origin,
-        const SrcElementwiseOperation& src_element_op)
+        const DstElementwiseOperation& dst_element_op)
         : threadwise_transfer_(src_desc,
                                make_zero_multi_index<nDim>(),
                                src_element_op,
                                dst_desc,
                                make_zero_multi_index<nDim>(),
-                               src_element_op)
+                               dst_element_op)
     {
         static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                           nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                           nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                       "wrong! nDim not consistent");

-        static_assert(is_same<BlockSliceLengths,
-                              decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
+        static_assert(is_same<BlockSliceLengths,
+                              decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
                       "wrong! threads should be mapped to cover entire slicing window");

         static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                       ...

@@ -74,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4
         const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
             make_multi_index(get_thread_local_1d_id()));

-        const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
+        const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

         threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
         ...

@@ -114,6 +117,16 @@ struct BlockwiseTensorSliceTransfer_v4
         }
     }

+    template <typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc& src_desc,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)
+    {
+        RunRead(src_desc, src_buf);
+        RunWrite(dst_desc, dst_buf);
+    }
+
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
           ...

@@ -152,8 +165,9 @@ struct BlockwiseTensorSliceTransfer_v4
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
                                            SrcElementwiseOperation,
                                            DstElementwiseOperation,
                                            DstInMemOp,
                                            SrcData,
                                            DstData,
                                            ...
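Note (editor's illustration, not part of the commit): the v4r1 transfer derives the per-thread tile from the block slice and the thread cluster (thread_slice_lengths = BlockSliceLengths / ThreadClusterLengths) and places each thread at thread_cluster_idx * thread_slice_lengths, which is what the "threads should be mapped to cover entire slicing window" assertion checks. A host-only sketch with made-up sizes, using plain arrays instead of CK's Sequence and descriptor types:

    #include <array>
    #include <cstdio>

    int main()
    {
        constexpr std::array<int, 2> block_slice_lengths{8, 64};    // e.g. K0PerBlock x MPerBlock
        constexpr std::array<int, 2> thread_cluster_lengths{4, 16}; // 64 threads in a 4x16 cluster

        std::array<int, 2> thread_slice_lengths{};
        for(int d = 0; d < 2; ++d)
            thread_slice_lengths[d] = block_slice_lengths[d] / thread_cluster_lengths[d]; // {2, 4}

        // Emulate thread_cluster_desc_.CalculateBottomIndex for a row-major cluster.
        for(int tid = 0; tid < thread_cluster_lengths[0] * thread_cluster_lengths[1]; ++tid)
        {
            const std::array<int, 2> cluster_idx{tid / thread_cluster_lengths[1],
                                                 tid % thread_cluster_lengths[1]};

            std::array<int, 2> data_idx_begin{};
            for(int d = 0; d < 2; ++d)
                data_idx_begin[d] = cluster_idx[d] * thread_slice_lengths[d];

            if(tid < 3)
                std::printf("thread %d copies a %dx%d tile starting at (%d, %d)\n",
                            tid, thread_slice_lengths[0], thread_slice_lengths[1],
                            data_idx_begin[0], data_idx_begin[1]);
        }
    }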
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp → composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v5r1.hpp (renamed)

-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer_v2.hpp"
+#include "threadwise_tensor_slice_transfer_v5r1.hpp"

 namespace ck {

@@ -31,13 +31,13 @@ template <index_t BlockSize,
           typename DstVectorTensorContiguousDimOrder,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v4r1
+struct BlockwiseTensorSliceTransfer_v5r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
+    __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
                                                            const Index& src_block_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_block_slice_origin)
     ...

@@ -134,7 +134,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
+        ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
                                            DstInMemOp,
                                            SrcData,
                                            DstData,
                                            ...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r1.hpp (new file, mode 100644)

#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
                                                           const Index& src_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r1<SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
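Note (editor's illustration, not part of the commit): every public method of the new v6r1 transfer repeats the same guard, so that when the block launches more threads than the thread cluster needs, only the first cluster-size threads do any copying. A host-only sketch of that predicate with hypothetical sizes:

    #include <cstdio>

    int main()
    {
        constexpr int block_size           = 256; // threads launched per block
        constexpr int cluster_element_size = 128; // thread_cluster_desc_.GetElementSize()

        int active = 0;
        for(int tid = 0; tid < block_size; ++tid)
        {
            const bool participates =
                (block_size == cluster_element_size) || (tid < cluster_element_size);
            if(participates)
                ++active; // in the kernel this thread would call threadwise_transfer_.Run(...)
        }

        std::printf("%d of %d threads participate in the transfer\n", active, block_size);
    }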
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r2.hpp (new file, mode 100644)

#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r2.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate
template <index_t BlockSize,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
                                                           const Index& src0_block_slice_origin,
                                                           const Src1Desc& src1_desc,
                                                           const Index& src1_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(src0_desc,
                                                    src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(src1_desc,
                                                    src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
                                           Src1Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
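Note (editor's illustration, not part of the commit): v6r2 (and v6r3 below, which adds a third source) fuses the element-wise functor into the copy itself, so two (or three) inputs are combined on the way to the destination instead of through an intermediate tensor. A host-only sketch of that fusion idea; the functor and sizes are hypothetical stand-ins, not CK types:

    #include <array>
    #include <cassert>
    #include <cstddef>

    struct AddExample // stand-in for an ElementwiseOperation such as AddRelu
    {
        void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }
    };

    template <typename Op, std::size_t N>
    void run_fused_copy(const std::array<float, N>& src0,
                        const std::array<float, N>& src1,
                        std::array<float, N>& dst,
                        Op op)
    {
        for(std::size_t i = 0; i < N; ++i)
            op(dst[i], src0[i], src1[i]); // element_op applied while the data is in flight
    }

    int main()
    {
        std::array<float, 4> a{1, 2, 3, 4}, b{10, 20, 30, 40}, out{};
        run_fused_copy(a, b, out, AddExample{});
        assert(out[2] == 33.0f);
    }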
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp (new file, mode 100644)

#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r3.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename Src2Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename Src2Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferSrc2ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
                                                           const Index& src0_block_slice_origin,
                                                           const Src1Desc& src1_desc,
                                                           const Index& src1_block_slice_origin,
                                                           const Src2Desc& src2_desc,
                                                           const Index& src2_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               src2_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(src0_desc,
                                                    src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(src1_desc,
                                                    src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc2SliceOrigin(src2_desc,
                                                    src2_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const Src2Desc& src2_desc,
                        const Src2Buffer& src2_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(
                src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
                                           Src1Data,
                                           Src2Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           Src2Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferSrc2ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
composable_kernel/include/tensor_operation/element_wise_operation.hpp (new file, mode 100644)

#ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct PassThrough
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        y = x;
    }

    // TODO remove this
    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return v;
    }
};

struct AddRelu
{
    template <typename T>
    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const
    {
        T a = x0 + x1;
        y   = a > 0 ? a : 0;
    }

    // TODO remove this
    template <typename T1>
    __host__ constexpr float operator()(float v0, T1 v1) const
    {
        float b = v0 + v1;
        float c = b > 0 ? b : 0;
        return c;
    }

    // TODO remove this
    template <typename T1>
    __device__ constexpr float operator()(float v0, T1 v1) const
    {
#if 0
        float a = v1 + v0;
        float b = max(a, float(0));
        return b;
#else
        float b = v1 + v0;
        float c = b > 0 ? b : 0;
        return c;
#endif
    }
};

struct AddReluAdd
{
    template <typename T>
    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1, const T& x2) const
    {
        T a = x0 + x1;
        T b = a > 0 ? a : 0;
        y   = b + x2;
    }

    // TODO remove this
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        float b = v0 + v1;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
    }

    // TODO remove this
    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
#if 0
        float a = v1 + v0;
        float b = max(a, float(0));
        float c = b + v2;
        return c;
#else
        float b = v1 + v2;
        float c = (v0 > -v1) ? b + v0 : v2;
        return c;
#endif
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct AddLeakyReluAdd
{
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        float a = v0 + v1;
        float b = 0.1 * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
    }

    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
#if 0
        // this use not too many registers, but use fp64 mul
        float a = v0 + v1;
        float b = 0.1 * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#elif 0
        // this spill register
        float a = v0 + v1;
        float b = float(0.1) * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#elif 0
        // this use lots of registers (but no spill)
        constexpr float alpha     = 0.1;
        constexpr float alpha_inv = 1.0 / alpha;

        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = b > 0 ? b : 0;
        float d = alpha * (a + c);
        return d;
#elif 1
        // this use lots of registers (but no spill), 89 Tflops
        constexpr float alpha     = 0.1;
        constexpr float alpha_inv = 1.0 / alpha;

        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = max(b, float(0));
        float d = alpha * (a + c);
        return d;
#elif 1
        // this spill registers, 89 Tflops
        float a     = v0 + v1;
        float alpha = 0.1;

        float b;
        asm volatile("\n \
                      v_mul_f32_e32 %0, %1, %2 \n \
                      "
                     : "=v"(b)
                     : "s"(alpha), "v"(a));

        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#endif
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
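Note (editor's check, not part of the commit): the register-friendly branch of AddLeakyReluAdd rewrites max(alpha * (v0 + v1), 0) + v2 as alpha * (v2 / alpha + max(v0 + v1, 0)), which is the same value up to floating-point rounding (alpha = 0.1 is positive, so it commutes with the max). A small host-only test of that equivalence, using sample inputs I picked for illustration:

    #include <algorithm>
    #include <cassert>
    #include <cmath>

    float add_leaky_relu_add_direct(float v0, float v1, float v2)
    {
        float a = v0 + v1;
        float b = 0.1f * a;
        float c = b > 0 ? b : 0;
        return c + v2;
    }

    float add_leaky_relu_add_refactored(float v0, float v1, float v2)
    {
        constexpr float alpha     = 0.1f;
        constexpr float alpha_inv = 1.0f / alpha;
        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = std::max(b, 0.0f);
        return alpha * (a + c);
    }

    int main()
    {
        const float samples[][3] = {{1.5f, -2.0f, 3.0f}, {2.0f, 1.0f, -0.5f}, {-4.0f, 0.5f, 7.0f}};
        for(const auto& s : samples)
        {
            const float d0 = add_leaky_relu_add_direct(s[0], s[1], s[2]);
            const float d1 = add_leaky_relu_add_refactored(s[0], s[1], s[2]);
            assert(std::fabs(d0 - d1) < 1e-4f); // equal up to rounding
        }
    }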
composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp

@@ -381,7 +381,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN...
                       "wrong!");

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
             ...

@@ -405,7 +405,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN...
             make_multi_index(0, 0, 0, 0, 0));

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
             InMemoryDataOperationEnum_t::Set,
             Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
             ...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp

@@ -6,7 +6,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_tensor_slice_transfer_v2.hpp"
+#include "blockwise_tensor_slice_transfer_v5r1.hpp"
 #include "threadwise_tensor_slice_transfer_v2.hpp"
 #include "threadwise_tensor_slice_set.hpp"

@@ -380,7 +380,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                       "wrong!");

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<BlockSize,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
                                                                   ...

@@ -404,7 +404,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             make_multi_index(0, 0, 0, 0));

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<BlockSize,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
                                                                   ...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp

@@ -6,9 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_tensor_slice_set.hpp"

 namespace ck {

@@ -40,15 +39,12 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-    constexpr index_t shared_block_size =
-        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-    __shared__ FloatAB p_shared_block[shared_block_size];
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_shared_block,
+                                                  p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                   ...

@@ -83,9 +79,6 @@ __global__ void
             const void CONSTANT* p_c_element_op,
             const void CONSTANT* p_block_2_ctile_map)
 {
-    constexpr index_t shared_block_size =
-        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
     const auto a_grid_desc_k0_m_k1 = *reinterpret_cast<const AGridDesc_K0_M_K1*>(
         cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1));
     const auto b_grid_desc_k0_n_k1 = *reinterpret_cast<const BGridDesc_K0_N_K1*>(
         ...

@@ -102,12 +95,12 @@ __global__ void
     const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
         cast_pointer_to_generic_address_space(p_c_element_op));

-    __shared__ FloatAB p_shared_block[shared_block_size];
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_shared_block,
+                                                  p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                   ...

@@ -135,9 +128,8 @@ template <index_t BlockSize,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
-          index_t MRepeat,
-          index_t NRepeat,
-          typename ABlockTransferThreadSliceLengths_K0_M_K1,
+          index_t MXdlPerWave,
+          index_t NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           ...

@@ -145,7 +137,7 @@ template <index_t BlockSize,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K0_N_K1,
+          bool ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           ...

@@ -153,17 +145,10 @@ template <index_t BlockSize,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BBlockLdsExtraN,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks,
-          bool CAccessOrderMRepeatNRepeat,
-          bool ABlockLdsExtraM,
-          bool BBlockLdsExtraN>
+          index_t CThreadTransferDstScalarPerVector>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};
     ...

@@ -178,7 +163,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
         ...

@@ -197,6 +182,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             }
         }();

+        return a_block_desc_k0_m_k1;
+    }
+
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
+    {
+        constexpr auto max_lds_align = K1;
+
         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_k0_n_k1 = [&]() {
             if constexpr(BBlockLdsExtraN)
             ...

@@ -212,14 +204,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             }
         }();

+        return b_block_desc_k0_n_k1;
+    }
+
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    {
         // LDS allocation for A and B: be careful of alignment
+        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+
+        constexpr auto max_lds_align = K1;
+
-        constexpr auto a_block_space_size =
+        constexpr auto a_block_space_size_aligned =
             math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

-        constexpr auto b_block_space_size =
+        constexpr auto b_block_space_size_aligned =
             math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

-        return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     ...

@@ -233,8 +236,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

-        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
-                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
+        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
+                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                       "Invalid tuning param!");

         const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
         ...

@@ -324,8 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             decltype(b_block_desc_k0_n_k1),
             MPerXDL,
             NPerXDL,
-            MRepeat,
-            NRepeat,
+            MXdlPerWave,
+            NXdlPerWave,
             K1>;

         return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
         ...

@@ -376,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
-        FloatAB* __restrict__ p_shared_block,
+        void* __restrict__ p_shared,
         const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
         ...

@@ -409,42 +412,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr auto max_lds_align = K1;

         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_m_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
+        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0_n_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
+        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                                                  AElementwiseOperation,
+                                                                  ck::tensor_operation::element_wise::PassThrough,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   Sequence<K0PerBlock, MPerBlock, K1>,
                                                                   ABlockTransferThreadSliceLengths_K0_M_K1,
                                                                   ABlockTransferThreadClusterLengths_K0_M_K1,
                                                                   ABlockTransferThreadClusterArrangeOrder,
                                                                   FloatAB,
                                                                   ...

@@ -460,19 +439,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
             true>(a_grid_desc_k0_m_k1,
                   make_multi_index(0, m_block_data_idx_on_grid, 0),
+                  a_element_op,
                   a_block_desc_k0_m_k1,
-                  make_multi_index(0, 0, 0),
-                  a_element_op);
+                  make_multi_index(0, 0, 0),
+                  ck::tensor_operation::element_wise::PassThrough{});

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                                                  BElementwiseOperation,
+                                                                  ck::tensor_operation::element_wise::PassThrough,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   Sequence<K0PerBlock, NPerBlock, K1>,
                                                                   BBlockTransferThreadSliceLengths_K0_N_K1,
                                                                   BBlockTransferThreadClusterLengths_K0_N_K1,
                                                                   BBlockTransferThreadClusterArrangeOrder,
                                                                   FloatAB,
                                                                   ...

@@ -488,11 +469,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
             true>(b_grid_desc_k0_n_k1,
                   make_multi_index(0, n_block_data_idx_on_grid, 0),
+                  b_element_op,
                   b_block_desc_k0_n_k1,
-                  make_multi_index(0, 0, 0),
-                  b_element_op);
+                  make_multi_index(0, 0, 0),
+                  ck::tensor_operation::element_wise::PassThrough{});

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
         ...

@@ -510,68 +493,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             decltype(b_block_desc_k0_n_k1),
             MPerXDL,
             NPerXDL,
-            MRepeat,
-            NRepeat,
+            MXdlPerWave,
+            NXdlPerWave,
             K1>{};

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size =
+        constexpr auto a_block_space_size_aligned =
             math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

-        FloatAB* p_a_block = p_shared_block;
-        FloatAB* p_b_block = p_shared_block + a_block_space_size;
-
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+            b_block_desc_k0_n_k1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-        // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
-        // hack to control index calculation when move slice window for A and B matrix for
-        // threadwise copy
-        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

         // preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

             a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
             b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
         }

-        // main body
-        index_t k0_block_data_begin = 0;
+        // Initialize C
+        c_thread_buf.Clear();

+        // main body
         if constexpr(HasMainKBlockLoop)
         {
+            index_t k0_block_data_begin = 0;
+
             do
             {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step,
-                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-                                                    b_block_slice_copy_step,
-                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);

-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);

                 block_sync_lds();

-                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

                 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
                 ...

@@ -619,8 +587,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const index_t n_thread_data_on_grid =
             n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
-
         const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
             make_single_stage_tensor_adaptor(
                 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                 ...

@@ -668,11 +634,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                           c_thread_buf,
                           c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                          c_grid_buf,
-                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+                          c_grid_buf);
         }
     }
-}; // namespace ck
+};

 } // namespace ck
 #endif
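Note (editor's illustration, not part of the commit): with this change the kernel sizes one __shared__ char buffer from GetSharedMemoryNumberOfByte() and carves it into an A region followed by a B region at the K1-aligned offset, instead of declaring a FloatAB array. A host-only sketch with made-up sizes of that arithmetic (the helper below mirrors the rounding of CK's math::integer_least_multiple, as I understand it):

    #include <cassert>
    #include <cstddef>

    using FloatAB = float;

    constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
    {
        return ((x + align - 1) / align) * align; // round x up to a multiple of align
    }

    int main()
    {
        constexpr std::size_t lds_align     = 16;          // stands in for max_lds_align (= K1)
        constexpr std::size_t a_block_elems = 3 * 129 * 8; // hypothetical a_block_desc element space
        constexpr std::size_t b_block_elems = 3 * 128 * 8; // hypothetical b_block_desc element space

        constexpr std::size_t a_aligned = integer_least_multiple(a_block_elems, lds_align);
        constexpr std::size_t b_aligned = integer_least_multiple(b_block_elems, lds_align);

        // GetSharedMemoryNumberOfByte(): the shared buffer is sized in bytes...
        constexpr std::size_t lds_bytes = (a_aligned + b_aligned) * sizeof(FloatAB);

        alignas(FloatAB) static char p_shared[lds_bytes]; // plays the role of the __shared__ buffer

        // ...and Run() places A at offset 0 and B right after the aligned A region.
        FloatAB* a_block = reinterpret_cast<FloatAB*>(p_shared);
        FloatAB* b_block = a_block + a_aligned;

        assert(reinterpret_cast<char*>(b_block + b_aligned) == p_shared + lds_bytes);
        return 0;
    }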
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp

@@ -6,9 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_tensor_slice_set.hpp"

 namespace ck {

@@ -19,6 +18,9 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
           typename CBlockClusterAdaptor,
           bool HasMainKBlockLoop>
 __global__ void
 ...

@@ -31,6 +33,9 @@ __global__ void
             const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
             const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
             const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op,
             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
     constexpr index_t shared_block_size =
     ...

@@ -45,6 +50,9 @@ __global__ void
                                                   a_b_k0_m_k1_grid_desc,
                                                   b_b_k0_n_k1_grid_desc,
                                                   c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  c_element_op,
                                                   c_block_cluster_adaptor);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER

@@ -129,11 +137,6 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks,
           bool CAccessOrderMRepeatNRepeat,
           bool ABlockLdsExtraM,
           bool BBlockLdsExtraN>
           ...

@@ -371,6 +374,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

         const index_t k_batch_id = block_work_idx[I0];
+
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
             ...

@@ -447,7 +451,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         }();

         // A matrix blockwise copy
-        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
+        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                                                  AElementwiseOperation,
+                                                                  ck::tensor_operation::element_wise::PassThrough,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                                   ABlockTransferThreadSliceLengths_K0_M_K1,
                                                                   ...

@@ -469,12 +475,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             true>(a_b_k0_m_k1_grid_desc,
                   make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
+                  a_element_op,
                   a_b_k0_m_k1_block_desc,
-                  make_multi_index(0, 0, 0, 0));
+                  make_multi_index(0, 0, 0, 0),
+                  ck::tensor_operation::element_wise::PassThrough{});

         // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
+        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                                                  BElementwiseOperation,
+                                                                  ck::tensor_operation::element_wise::PassThrough,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                                   BBlockTransferThreadSliceLengths_K0_N_K1,
                                                                   ...

@@ -496,8 +506,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             true>(b_b_k0_n_k1_grid_desc,
                   make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
+                  b_element_op,
                   b_b_k0_n_k1_block_desc,
-                  make_multi_index(0, 0, 0, 0));
+                  make_multi_index(0, 0, 0, 0),
+                  ck::tensor_operation::element_wise::PassThrough{});

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
         ...

@@ -531,15 +543,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

-        // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
-        // hack to control index calculation when move slice window for A and B matrix for
-        // threadwise copy
-        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
-
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             ...

@@ -547,33 +550,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         // preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
+            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

             a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
             b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
         }

+        // Initialize C
+        c_thread_buf.Clear();
+
         // main body
-        index_t k_block_data_begin = 0;
-
         if constexpr(HasMainKBlockLoop)
         {
+            index_t k0_block_data_begin = 0;
+
             do
             {
-                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);

-                a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

                 block_sync_lds();

-                b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

                 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
                 ...

@@ -622,8 +623,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         const index_t n_thread_data_on_grid =
             n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
-
         const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
             make_single_stage_tensor_adaptor(
                 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                 ...

@@ -648,6 +647,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             FloatC,
             decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
             decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
+            CElementwiseOperation,
             Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             ...

@@ -664,14 +664,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
                                  m_thread_data_on_grid_idx[I2],
                                  m_thread_data_on_grid_idx[I3],
                                  m_thread_data_on_grid_idx[I4],
-                                 n_thread_data_on_grid_idx[I2])};
+                                 n_thread_data_on_grid_idx[I2]),
+                c_element_op};

         c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                           c_thread_buf,
                           c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                          c_grid_buf,
-                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+                          c_grid_buf);
     }
 }; // namespace ck
 ...
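Note (editor's illustration, not part of the commit): the restructured main body now clears the accumulator up front, then preloads the first K-slab and, inside the HasMainKBlockLoop branch, advances the slice window and reads the next slab while the current one is consumed. A much-simplified host-only sketch of that general preload-and-overlap shape, with hypothetical sizes and a plain sum standing in for the blockwise GEMM (the exact ordering of writes and syncs in the kernel is not reproduced here):

    #include <algorithm>
    #include <cassert>
    #include <numeric>
    #include <vector>

    int main()
    {
        const int k_slabs   = 4;
        const int slab_size = 8;
        std::vector<int> a_grid(k_slabs * slab_size);
        std::iota(a_grid.begin(), a_grid.end(), 1);

        std::vector<int> lds(slab_size);     // stands in for a_block_buf in shared memory
        std::vector<int> staging(slab_size); // stands in for the threadwise register buffer
        long long c_acc = 0;                 // stands in for c_thread_buf, cleared up front

        auto run_read = [&](int slab, std::vector<int>& dst) {
            std::copy_n(a_grid.begin() + slab * slab_size, slab_size, dst.begin());
        };

        run_read(0, lds); // preload data into LDS

        for(int slab = 0; slab + 1 < k_slabs; ++slab)
        {
            run_read(slab + 1, staging);                           // read the next slab
            c_acc += std::accumulate(lds.begin(), lds.end(), 0LL); // compute on the current slab
            lds = staging;                                         // publish the next slab
        }
        c_acc += std::accumulate(lds.begin(), lds.end(), 0LL);     // tail iteration

        assert(c_acc == (long long)a_grid.size() * (a_grid.size() + 1) / 2);
    }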
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp
View file @
5c8112d9
...
...
@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "blockwise_tensor_slice_transfer
_v4r1
.hpp"
#include "threadwise_tensor_slice_transfer_v1r4.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace
ck
{
...
...
@@ -88,7 +87,6 @@ template <index_t BlockSize,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -96,7 +94,7 @@ template <index_t BlockSize,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
...
...
@@ -104,17 +102,10 @@ template <index_t BlockSize,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGridStepHacks
,
typename
BGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsExtraM
,
bool
BBlockLdsExtraN
>
index_t
CThreadTransferDstScalarPerVector
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -410,11 +401,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        // A matrix blockwise copy
        auto a_blockwise_copy =
-           BlockwiseTensorSliceTransfer_v4<BlockSize,
+           BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                             AElementwiseOperation,
+                                             ck::tensor_operation::element_wise::PassThrough,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
...
...
@@ -430,19 +421,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
                                             true>(a_grid_desc_k0_m_k1,
                                                   make_multi_index(0, m_block_data_idx_on_grid, 0),
+                                                  a_element_op,
                                                   a_block_desc_k0_m_k1,
                                                   make_multi_index(0, 0, 0),
-                                                  a_element_op);
+                                                  ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
-           BlockwiseTensorSliceTransfer_v4<BlockSize,
+           BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+                                             BElementwiseOperation,
+                                             ck::tensor_operation::element_wise::PassThrough,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
...
...
@@ -458,11 +451,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
                                             1,
                                             1,
                                             BThreadTransferSrcResetCoordinateAfterRun,
                                             true>(b_grid_desc_k0_n_k1,
                                                   make_multi_index(0, n_block_data_idx_on_grid, 0),
+                                                  b_element_op,
                                                   b_block_desc_k0_n_k1,
                                                   make_multi_index(0, 0, 0),
-                                                  b_element_op);
+                                                  ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
...
...
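The A/B blockwise copies above now take two elementwise operations: a grid-side op (a_element_op / b_element_op) applied while reading from global memory, and a PassThrough for the LDS-side write. The snippet below is a host C++ analogy of that calling convention, not the BlockwiseTensorSliceTransfer_v4r1 implementation; copy_with_ops and the lambdas are hypothetical stand-ins.

#include <cstdio>

// Hedged sketch: the transfer applies a source-side op on load and a destination-side op
// on store; in the kernel above the destination-side op is the identity PassThrough.
template <typename SrcOp, typename DstOp>
void copy_with_ops(const float* src, float* dst, int n, SrcOp src_op, DstOp dst_op)
{
    for(int i = 0; i < n; ++i)
        dst[i] = dst_op(src_op(src[i])); // src op on load, dst op on store
}

int main()
{
    float src[4] = {1.f, 2.f, 3.f, 4.f}, dst[4] = {};
    copy_with_ops(src, dst, 4,
                  [](float x) { return x + 1.0f; }, // stands in for a_element_op
                  [](float x) { return x; });       // PassThrough on the LDS side
    printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]); // 2 3 4 5
    return 0;
}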
@@ -496,15 +491,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-       // hack to control index calculation when iterating over A and B matrix for threadwise copy
-       constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-       constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
-       // hack to control index calculation when move slice window for A and B matrix for
-       // threadwise copy
-       constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-       constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...
...
@@ -512,34 +498,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        // preload data into LDS
        {
-           a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-           b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+           a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
+           b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
        }

-       // main body
-       index_t k0_block_data_begin = 0;
+       // Initialize C
+       c_thread_buf.Clear();

+       // main body
        if constexpr(HasMainKBlockLoop)
        {
+           index_t k0_block_data_begin = 0;

            do
            {
-               a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                   a_block_slice_copy_step,
-                                                   a_k0_m_k1_grid_move_slice_window_step_hack);
-               b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-                                                   b_block_slice_copy_step,
-                                                   b_k0_n_k1_grid_move_slice_window_step_hack);
+               a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+               b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);

-               a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+               a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);

                block_sync_lds();

-               b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+               b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
...
...
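With the step hacks gone, the main K loop reads as: move the A/B source windows forward, prefetch the next K0PerBlock slice from global memory into the copy's thread buffers, synchronize, publish the prefetched tile to LDS, and run the blockwise GEMM on the tile that is resident. The sketch below is a host-side analogy of that single-stage pipeline under simplified assumptions (scalars stand in for whole tiles, and the exact sync/write ordering differs in the kernel); it is not the kernel code.

#include <cstdio>
#include <vector>

// Host-side analogy (assumption, not the kernel): one tile is kept "in flight" per iteration.
// "RunRead" prefetches the next slice into register stand-ins, the GEMM consumes the tile
// resident in the LDS stand-ins, and the prefetched tile is then published for the next
// iteration (block_sync_lds() guards that hand-over in the real kernel).
int main()
{
    const int K0 = 8, K0PerBlock = 2;
    std::vector<float> a_grid(K0, 1.0f), b_grid(K0, 2.0f);

    float a_regs, b_regs; // blockwise-copy thread buffers
    float a_lds, b_lds;   // LDS tiles
    float c = 0.0f;

    // preload data into LDS
    a_regs = a_grid[0];
    b_regs = b_grid[0];
    a_lds  = a_regs;
    b_lds  = b_regs;

    // main body
    for(int k0 = 0; k0 + K0PerBlock < K0; k0 += K0PerBlock)
    {
        a_regs = a_grid[k0 + K0PerBlock]; // RunRead: prefetch the next tile
        b_regs = b_grid[k0 + K0PerBlock];

        c += a_lds * b_lds;               // blockwise_gemm.Run on the resident tile

        a_lds = a_regs;                   // RunWrite: publish the prefetched tile
        b_lds = b_regs;
    }

    c += a_lds * b_lds;                   // tail: GEMM on the last tile

    printf("c = %f\n", c);                // 4 tiles of 1*2 -> c = 8.000000
    return 0;
}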
@@ -588,8 +571,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
        const index_t n_thread_data_on_grid =
            n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-       constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

        const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
...
...
@@ -642,14 +623,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
                          c_thread_buf,
                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          c_grid_buf,
-                         c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks,
                          c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          c0_grid_buf,
                          c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          c1_grid_buf);
        }
    }
};
-// namespace ck
+} // namespace ck
#endif
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp  0 → 100644  View file @ 5c8112d9  (new file, diff collapsed)
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp  0 → 100644  View file @ 5c8112d9  (new file, diff collapsed)
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp  0 → 100644  View file @ 5c8112d9  (new file, diff collapsed)
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp  0 → 100644  View file @ 5c8112d9  (new file, diff collapsed)
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp  View file @ 5c8112d9
...
...
@@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                          const DstDesc& dst_desc,
                          DstBuffer& dst_buf)
    {
-       constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
+       constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
...
...
@@ -326,7 +326,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
            });
...
...
@@ -506,7 +506,7 @@ struct ThreadwiseTensorSliceTransfer_v2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_access_idx[I0];

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
            });
...
...
@@ -638,7 +638,7 @@ struct ThreadwiseTensorSliceTransfer_v2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
            });
...
...
@@ -835,7 +835,7 @@ struct ThreadwiseTensorSliceTransfer_v3
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_src_access_idx[I0];

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
            });
...
...
@@ -992,7 +992,7 @@ struct ThreadwiseTensorSliceTransfer_v3
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_dst_access_idx[I0];

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
            });
...
...
@@ -1136,7 +1136,7 @@ struct ThreadwiseTensorSliceTransfer_v3
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_src_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
            });
...
...
@@ -1196,7 +1196,7 @@ struct ThreadwiseTensorSliceTransfer_v3
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_dst_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
            });
...
...
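All of the hunks above in threadwise_tensor_slice_transfer.hpp make the same one-character fix: the inner static_for that folds higher dimensions into a flat index now starts at 1 instead of 0, because tmp is already seeded with the dimension-0 term. The following standalone host C++ illustration of that linearization uses plain loops instead of static_for and made-up lengths and indices.

#include <cstdio>

// Hedged illustration of the index-linearization pattern these hunks fix: tmp is seeded with
// the dim-0 term, so the accumulation must start at j = 1; starting at j = 0 would fold the
// dim-0 index in twice. Plain host C++, not the library's static_for code.
int main()
{
    const int nDim = 3;
    int lengths[nDim] = {4, 3, 5};
    int idx[nDim]     = {2, 1, 3};

    int tmp = idx[0];             // seed with dimension 0
    for(int j = 1; j < nDim; ++j) // the "static_for<1, i, 1>" of the patch
        tmp = tmp * lengths[j] + idx[j];

    printf("%d\n", tmp);          // (2*3 + 1)*5 + 3 = 38
    return 0;
}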
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp  View file @ 5c8112d9
...
...
@@ -116,9 +116,6 @@ struct ThreadwiseTensorSliceTransfer_v1r4
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

-       constexpr auto dst_scalar_step_in_vector = generate_sequence(
-           detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};
...
...
@@ -141,7 +138,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
            Number<nDim>{});

        // make forward steps: dst0
-       // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+       // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
+       // DstScalarPerVector
        // TODO: fix this
        const auto dst0_forward_steps = generate_tuple(
            [&](auto i) {
...
...
@@ -157,7 +155,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
            Number<nDim>{});

        // make forward steps: dst1
-       // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+       // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
+       // DstScalarPerVector
        // TODO: fix this
        const auto dst1_forward_steps = generate_tuple(
            [&](auto i) {
...
...
@@ -187,7 +186,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
            Number<nDim>{});

        // make backward steps: dst0
-       // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+       // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
+       // DstScalarPerVector
        // TODO: fix this
        const auto dst0_backward_steps = generate_tuple(
            [&](auto i) {
...
...
@@ -203,7 +203,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
            Number<nDim>{});

        // make backward steps: dst1
-       // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+       // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
+       // DstScalarPerVector
        // TODO: fix this
        const auto dst1_backward_steps = generate_tuple(
            [&](auto i) {
...
...
@@ -229,7 +230,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_access_idx[I0];

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
            });
...
...
@@ -397,14 +398,12 @@ struct ThreadwiseTensorSliceTransfer_v1r4
              typename SrcBuffer,
              typename DstBuffer,
              typename Dst0Buffer,
-             typename Dst1Buffer,
-             typename DstStepHacks>
+             typename Dst1Buffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
-                       const DstStepHacks& dst_step_hacks,
                        const Dst0Desc& dst0_desc,
                        const Dst0Buffer& dst0_buf,
                        const Dst1Desc& dst1_desc,
...
...
@@ -427,7 +426,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
                          src_buf,
                          dst_desc,
                          dst_buf,
-                         dst_step_hacks,
+                         f_step_hacks(dst_desc),
                          dst0_desc,
                          dst0_buf,
                          f_step_hacks(dst0_desc),
...
...
@@ -461,7 +460,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
            });
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp  0 → 100644  View file @ 5c8112d9  (new file, diff collapsed)
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp → composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp  View file @ 5c8112d9
-#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP
-#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP
+#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP
+#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R1_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
...
...
@@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst
// 4. Use thread buffer
template <typename SliceLengths,
          typename SrcElementwiseOperation,
+         typename DstElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename SrcData,
          typename DstData,
...
...
@@ -66,7 +67,7 @@ template <typename SliceLengths,
          bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
-struct ThreadwiseTensorSliceTransfer_v3r2
+struct ThreadwiseTensorSliceTransfer_v3r1
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;
...
...
@@ -77,15 +78,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2
    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

-   __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
+   __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
-       const SrcElementwiseOperation& src_element_op)
+       const DstElementwiseOperation& dst_element_op)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
-         src_element_op_(src_element_op)
+         src_element_op_(src_element_op),
+         dst_element_op_(dst_element_op)
    {
    }
...
...
@@ -165,7 +168,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_src_access_idx[I0];

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
            });
...
...
@@ -412,7 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_dst_access_idx[I0];

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
            });
...
...
@@ -442,13 +445,24 @@ struct ThreadwiseTensorSliceTransfer_v3r2
            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

-           using dst_vector_t = typename vector_type_maker_t<DstData, DstScalarPerVector>::type;
+           using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
+           using dst_vector_t    = typename dst_vector_type::type;

-           // copy data from dst_thread_scratch_ to dst_buf
+           // copy data from dst_thread_scratch_ into dst_vector_container
+           auto dst_vector_container = dst_vector_type{
+               dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
+
+           // apply DstElementwiseOperation on dst_vector_container
+           static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+               dst_vector_container.template AsType<DstData>()(i) =
+                   dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
+           });
+
+           // copy data from dst_vector_container to dst_buf
            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
-               dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq));
+               dst_vector_container.template AsType<dst_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            {
...
...
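The RunWrite change above first gathers DstScalarPerVector scalars from the thread scratch into a vector container, applies dst_element_op_ to each lane, and only then issues the vectorized store. The following host C++ sketch mirrors that sequence under simplified assumptions: plain arrays instead of vector_type_maker_t, and a lambda standing in for DstElementwiseOperation.

#include <cstdio>

// Hedged sketch of the new RunWrite step: gather, apply the destination op per lane, store.
int main()
{
    const int DstScalarPerVector = 4;
    float dst_thread_scratch[DstScalarPerVector] = {1.f, 2.f, 3.f, 4.f};
    float dst_buf[DstScalarPerVector]            = {};

    auto dst_element_op = [](float x) { return 2.0f * x; }; // illustrative op

    float dst_vector_container[DstScalarPerVector];
    for(int i = 0; i < DstScalarPerVector; ++i)  // the static_for in the patch
        dst_vector_container[i] = dst_element_op(dst_thread_scratch[i]);

    for(int i = 0; i < DstScalarPerVector; ++i)  // stands in for the vectorized Set<>
        dst_buf[i] = dst_vector_container[i];

    printf("%f %f %f %f\n", dst_buf[0], dst_buf[1], dst_buf[2], dst_buf[3]); // 2 4 6 8
    return 0;
}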
@@ -498,7 +512,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
-       constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
+       constexpr index_t ntransform_src = remove_cvref_t<SrcDesc>::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
...
...
@@ -512,7 +526,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
-       constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
+       // TODO: why need remove_cvref_t ?
+       constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
...
...
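Both RunRead and RunWrite now call GetNumOfTransform() through remove_cvref_t<Desc> rather than Desc directly, and the TODO in the diff asks why. A plausible reason (an assumption, not confirmed by the diff) is that the descriptor type parameter can be deduced as a reference, and a qualified name such as const Desc&::GetNumOfTransform() does not compile, so the cv/ref qualifiers must be stripped first. A minimal standalone illustration:

#include <type_traits>

// Hedged illustration: strip cv/ref before the qualified static call, which is what the
// library's remove_cvref_t helper does (std::remove_cvref_t itself is C++20).
struct Desc
{
    static constexpr int GetNumOfTransform() { return 3; }
};

template <typename DescType>
constexpr int num_transforms()
{
    using Plain = std::remove_cv_t<std::remove_reference_t<DescType>>;
    return Plain::GetNumOfTransform();
}

static_assert(num_transforms<const Desc&>() == 3, "works even when deduced as a reference type");

int main() { return 0; }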
@@ -548,7 +563,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_src_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
            });
...
...
@@ -608,7 +623,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
        static_for<1, nDim, 1>{}([&](auto i) {
            index_t tmp = ordered_dst_access_lengths[I0] - 1;

-           static_for<0, i, 1>{}([&](auto j) {
+           static_for<1, i, 1>{}([&](auto j) {
                tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
            });
...
...
@@ -811,6 +826,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
    SrcCoord src_coord_;
    DstCoord dst_coord_;
    const SrcElementwiseOperation src_element_op_;
+   const DstElementwiseOperation dst_element_op_;
};

} // namespace ck
...
...