Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a24ad25
Commit
1a24ad25
authored
Apr 21, 2022
by
Chao Liu
Browse files
refactor
parent
7cd48ef1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
51 deletions
+47
-51
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
...quant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
+45
-42
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+1
-4
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+1
-5
No files found.
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
View file @
1a24ad25
...
@@ -13,8 +13,7 @@
...
@@ -13,8 +13,7 @@
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
...
@@ -54,47 +53,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -54,47 +53,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
<
ADataType
,
// ADataType
ALayout
,
// typename ALayout,
BDataType
,
// BDataType
BLayout
,
// typename BLayout,
CDataType
,
// CDataType
CLayout
,
// typename CLayout,
AccDataType
,
// AccDataType
ADataType
,
// typename ADataType,
CShuffleDataType
,
// CShuffleDataType
BDataType
,
// typename BDataType,
ALayout
,
// ALayout
CDataType
,
// typename CDataType,
BLayout
,
// BLayout
AccDataType
,
// typename GemmAccDataType,
CLayout
,
// CLayout
CShuffleDataType
,
// typename CShuffleDataType,
PassThrough
,
// AElementwiseOperation
PassThrough
,
// typename AElementwiseOperation,
PassThrough
,
// BElementwiseOperation
PassThrough
,
// typename BElementwiseOperation,
RequantReluRequant
,
// CElementwiseOperation
RequantReluRequant
,
// typename CElementwiseOperation,
256
,
// BlockSize
GemmDefault
,
// GemmSpecialization GemmSpec,
256
,
// MPerBlock
1
,
// index_t NumGemmKPrefetchStage,
128
,
// NPerBlock
256
,
// index_t BlockSize,
64
,
// KPerBlock
256
,
// index_t MPerBlock,
16
,
// AK1
128
,
// index_t NPerBlock,
16
,
// BK1
64
,
// index_t KPerBlock,
32
,
// MPerXDL
16
,
// index_t AK1,
32
,
// NPerXDL
16
,
// index_t BK1,
4
,
// MXdlPerWave
32
,
// index_t MPerXDL,
2
,
// NXdlPerWave
32
,
// index_t NPerXDL,
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
4
,
// index_t MXdlPerWave,
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
2
,
// index_t NXdlPerWave,
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
4
,
64
,
1
>
,
// typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
2
,
// ABlockTransferSrcVectorDim
S
<
1
,
0
,
2
>
,
// typename ABlockTransferThreadClusterArrangeOrder,
16
,
// ABlockTransferSrcScalarPerVector
S
<
1
,
0
,
2
>
,
// typename ABlockTransferSrcAccessOrder,
16
,
// ABlockTransferDstScalarPerVector_K1
2
,
// index_t ABlockTransferSrcVectorDim,
true
,
// ABlockLdsAddExtraM
16
,
// index_t ABlockTransferSrcScalarPerVector,
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
16
,
// index_t ABlockTransferDstScalarPerVector_AK1,
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
1
,
// bool ABlockLdsExtraM,
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
4
,
64
,
1
>
,
// typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
2
,
// BBlockTransferSrcVectorDim
S
<
1
,
0
,
2
>
,
// typename BBlockTransferThreadClusterArrangeOrder,
16
,
// BBlockTransferSrcScalarPerVector
S
<
1
,
0
,
2
>
,
// typename BBlockTransferSrcAccessOrder,
16
,
// BBlockTransferDstScalarPerVector_K1
2
,
// index_t BBlockTransferSrcVectorDim,
true
,
// BBlockLdsAddExtraN
8
,
// index_t BBlockTransferSrcScalarPerVector,
1
,
// CShuffleMXdlPerWavePerShuffle
8
,
// index_t BBlockTransferDstScalarPerVector_BK1,
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// bool BBlockLdsExtraN,
S
<
1
,
1
,
64
,
1
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
1
,
// index_t CShuffleMXdlPerWavePerShuffle,
16
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
1
,
// index_t CShuffleNXdlPerWavePerShuffle,
S
<
1
,
64
,
1
,
4
>
,
// typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16
>
;
// index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
1a24ad25
#ifndef REFERENCE_GEMM_HPP
#pragma once
#define REFERENCE_GEMM_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -129,4 +127,3 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -129,4 +127,3 @@ struct ReferenceGemm : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/utility/check_err.hpp
View file @
1a24ad25
#ifndef CHECK_ERR_HPP
#pragma once
#define CHECK_ERR_HPP
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include <cstdlib>
#include <cstdlib>
...
@@ -194,5 +192,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
...
@@ -194,5 +192,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
std
::
copy
(
std
::
begin
(
v
),
std
::
end
(
v
),
std
::
ostream_iterator
<
T
>
(
os
,
" "
));
std
::
copy
(
std
::
begin
(
v
),
std
::
end
(
v
),
std
::
ostream_iterator
<
T
>
(
os
,
" "
));
return
os
;
return
os
;
}
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment