Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
96c73d70
Commit
96c73d70
authored
Apr 21, 2022
by
Chao Liu
Browse files
add missing type convert
parent
2d35fac0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
104 deletions
+109
-104
example/01_gemm/gemm_xdl_bf16.cpp
example/01_gemm/gemm_xdl_bf16.cpp
+45
-42
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+1
-1
example/01_gemm/gemm_xdl_int8.cpp
example/01_gemm/gemm_xdl_int8.cpp
+47
-49
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+9
-10
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
+7
-2
No files found.
example/01_gemm/gemm_xdl_bf16.cpp
View file @
96c73d70
...
...
@@ -11,8 +11,7 @@
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
...
...
@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle
<
ADataType
,
// ADataType
BDataType
,
// BDataType
CDataType
,
// CDataType
AccDataType
,
// AccDataType
CDataType
,
// CShuffleDataType
ALayout
,
// ALayout
BLayout
,
// BLayout
CLayout
,
// CLayout
PassThrough
,
// AElementwiseOperation
PassThrough
,
// BElementwiseOperation
PassThrough
,
// CElementwiseOperation
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
8
,
// AK1
8
,
// BK1
32
,
// MPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
<
ALayout
,
// typename ALayout,
BLayout
,
// typename BLayout,
CLayout
,
// typename CLayout,
ADataType
,
// typename ADataType,
BDataType
,
// typename BDataType,
CDataType
,
// typename CDataType,
AccDataType
,
// typename GemmAccDataType,
CDataType
,
// typename CShuffleDataType,
PassThrough
,
// typename AElementwiseOperation,
PassThrough
,
// typename BElementwiseOperation,
PassThrough
,
// typename CElementwiseOperation,
GemmDefault
,
// GemmSpecialization GemmSpec,
1
,
// index_t NumGemmKPrefetchStage,
256
,
// index_t BlockSize,
256
,
// index_t MPerBlock,
128
,
// index_t NPerBlock,
32
,
// index_t KPerBlock,
8
,
// index_t AK1,
8
,
// index_t BK1,
32
,
// index_t MPerXDL,
32
,
// index_t NPerXDL,
4
,
// index_t MXdlPerWave,
2
,
// index_t NXdlPerWave,
S
<
4
,
64
,
1
>
,
// typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
S
<
1
,
0
,
2
>
,
// typename ABlockTransferThreadClusterArrangeOrder,
S
<
1
,
0
,
2
>
,
// typename ABlockTransferSrcAccessOrder,
2
,
// index_t ABlockTransferSrcVectorDim,
8
,
// index_t ABlockTransferSrcScalarPerVector,
8
,
// index_t ABlockTransferDstScalarPerVector_AK1,
1
,
// bool ABlockLdsExtraM,
S
<
4
,
64
,
1
>
,
// typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
S
<
1
,
0
,
2
>
,
// typename BBlockTransferThreadClusterArrangeOrder,
S
<
1
,
0
,
2
>
,
// typename BBlockTransferSrcAccessOrder,
2
,
// index_t BBlockTransferSrcVectorDim,
8
,
// index_t BBlockTransferSrcScalarPerVector,
8
,
// index_t BBlockTransferDstScalarPerVector_BK1,
1
,
// bool BBlockLdsExtraN,
1
,
// index_t CShuffleMXdlPerWavePerShuffle,
1
,
// index_t CShuffleNXdlPerWavePerShuffle,
S
<
1
,
32
,
1
,
8
>
,
// typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
8
>
;
// index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
example/01_gemm/gemm_xdl_fp16.cpp
View file @
96c73d70
...
...
@@ -46,7 +46,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpeciali
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
// clang-format off
#if
0
#if
1
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
...
example/01_gemm/gemm_xdl_int8.cpp
View file @
96c73d70
...
...
@@ -11,8 +11,7 @@
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
...
...
@@ -20,64 +19,63 @@
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int
32
_t
;
using
CDataType
=
int
8
_t
;
using
AccDataType
=
int32_t
;
using
CShuffleDataType
=
int
32
_t
;
using
CShuffleDataType
=
int
8
_t
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle
<
ADataType
,
// ADataType
BDataType
,
// BDataType
CDataType
,
// CDataType
AccDataType
,
// AccDataType
CShuffleDataType
,
// CShuffleDataType
ALayout
,
// ALayout
BLayout
,
// BLayout
CLayout
,
// CLayout
PassThrough
,
// AElementwiseOperation
PassThrough
,
// BElementwiseOperation
PassThrough
,
// CElementwiseOperation
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
16
,
// AK1
16
,
// BK1
32
,
// MPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
16
,
// ABlockTransferSrcScalarPerVector
16
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
16
,
// BBlockTransferSrcScalarPerVector
16
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
4
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
<
ALayout
,
// typename ALayout,
BLayout
,
// typename BLayout,
CLayout
,
// typename CLayout,
ADataType
,
// typename ADataType,
BDataType
,
// typename BDataType,
CDataType
,
// typename CDataType,
AccDataType
,
// typename GemmAccDataType,
CShuffleDataType
,
// typename CShuffleDataType,
PassThrough
,
// typename AElementwiseOperation,
PassThrough
,
// typename BElementwiseOperation,
PassThrough
,
// typename CElementwiseOperation,
GemmDefault
,
// GemmSpecialization GemmSpec,
1
,
// index_t NumGemmKPrefetchStage,
256
,
// index_t BlockSize,
256
,
// index_t MPerBlock,
128
,
// index_t NPerBlock,
64
,
// index_t KPerBlock,
16
,
// index_t AK1,
16
,
// index_t BK1,
32
,
// index_t MPerXDL,
32
,
// index_t NPerXDL,
4
,
// index_t MXdlPerWave,
2
,
// index_t NXdlPerWave,
S
<
4
,
64
,
1
>
,
// typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
S
<
1
,
0
,
2
>
,
// typename ABlockTransferThreadClusterArrangeOrder,
S
<
1
,
0
,
2
>
,
// typename ABlockTransferSrcAccessOrder,
2
,
// index_t ABlockTransferSrcVectorDim,
16
,
// index_t ABlockTransferSrcScalarPerVector,
16
,
// index_t ABlockTransferDstScalarPerVector_AK1,
1
,
// bool ABlockLdsExtraM,
S
<
4
,
64
,
1
>
,
// typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
S
<
1
,
0
,
2
>
,
// typename BBlockTransferThreadClusterArrangeOrder,
S
<
1
,
0
,
2
>
,
// typename BBlockTransferSrcAccessOrder,
2
,
// index_t BBlockTransferSrcVectorDim,
8
,
// index_t BBlockTransferSrcScalarPerVector,
8
,
// index_t BBlockTransferDstScalarPerVector_BK1,
1
,
// bool BBlockLdsExtraN,
1
,
// index_t CShuffleMXdlPerWavePerShuffle,
1
,
// index_t CShuffleNXdlPerWavePerShuffle,
S
<
1
,
64
,
1
,
4
>
,
// typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16
>
;
// index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
96c73d70
...
...
@@ -51,7 +51,7 @@ template <typename SrcData,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
Dst
ElementwiseOperation
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
DstVectorDim
,
...
...
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
__device__
constexpr
ThreadwiseTensorSliceTransfer_v1r3
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
,
const
DstElementwiseOperation
&
dst_element_op
)
__device__
constexpr
ThreadwiseTensorSliceTransfer_v1r3
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
,
const
ElementwiseOperation
&
element_op
)
:
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
)),
dst_
element_op_
{
dst_
element_op
}
element_op_
{
element_op
}
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
...
...
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
SrcData
dst_
v
;
SrcData
v
;
// apply element-wise operation
dst_
element_op_
(
dst_
v
,
src_buf
[
Number
<
src_offset
>
{}]);
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
dst_
v
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
});
const
bool
is_dst_valid
=
...
...
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private:
DstCoord
dst_coord_
;
const
Dst
ElementwiseOperation
dst_
element_op_
;
const
ElementwiseOperation
element_op_
;
};
// namespace ThreadwiseTensorSliceTransfer_v1r3
// Assume:
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
View file @
96c73d70
...
...
@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
// apply pointwise operation
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
element_op_
(
dst_vector_container
.
template
AsType
<
DstData
>()(
i
),
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
SrcData
v
;
// apply element-wise operation
element_op_
(
v
,
src_vector_container
.
template
AsType
<
SrcData
>()[
i
]);
// apply type convert
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
});
const
bool
is_dst_valid
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment