Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b58b98ff
Commit
b58b98ff
authored
Jun 15, 2022
by
Chao Liu
Browse files
add ckProfiler
parent
3d005816
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
94 additions
and
75 deletions
+94
-75
include/ck/tensor_description/tensor_adaptor.hpp
include/ck/tensor_description/tensor_adaptor.hpp
+4
-0
include/ck/tensor_description/tensor_descriptor.hpp
include/ck/tensor_description/tensor_descriptor.hpp
+7
-0
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp
...y/tensor_operation_instance/device_operation_instance.hpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
+4
-4
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
+25
-26
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
+4
-4
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
+3
-3
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+2
-2
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
+12
-7
profiler/src/profile_gemm_add_add_fastgelu.cpp
profiler/src/profile_gemm_add_add_fastgelu.cpp
+29
-25
profiler/src/profiler.cpp
profiler/src/profiler.cpp
+2
-2
No files found.
include/ck/tensor_description/tensor_adaptor.hpp
View file @
b58b98ff
...
@@ -136,7 +136,11 @@ struct TensorAdaptor
...
@@ -136,7 +136,11 @@ struct TensorAdaptor
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
public:
public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorAdaptor() = default;
__host__ __device__ constexpr TensorAdaptor() = default;
#else
__host__
__device__
constexpr
TensorAdaptor
()
:
transforms_
{},
element_size_
{}
{}
#endif
__host__
__device__
constexpr
TensorAdaptor
(
const
Transforms
&
transforms
)
__host__
__device__
constexpr
TensorAdaptor
(
const
Transforms
&
transforms
)
:
transforms_
{
transforms
},
element_size_
{
InitializeElementSize
(
transforms
)}
:
transforms_
{
transforms
},
element_size_
{
InitializeElementSize
(
transforms
)}
...
...
include/ck/tensor_description/tensor_descriptor.hpp
View file @
b58b98ff
...
@@ -111,7 +111,14 @@ struct TensorDescriptor
...
@@ -111,7 +111,14 @@ struct TensorDescriptor
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
public:
public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr TensorDescriptor() = default;
#else
__host__
__device__
constexpr
TensorDescriptor
()
:
transforms_
{},
element_size_
{},
element_space_size_
{}
{
}
#endif
__host__
__device__
constexpr
TensorDescriptor
(
const
Transforms
&
transforms
,
__host__
__device__
constexpr
TensorDescriptor
(
const
Transforms
&
transforms
,
ElementSpaceSize
element_space_size
)
ElementSpaceSize
element_space_size
)
...
...
include/ck/utility/tuple.hpp
View file @
b58b98ff
...
@@ -18,7 +18,7 @@ struct TupleElementKey
...
@@ -18,7 +18,7 @@ struct TupleElementKey
template
<
typename
Key
,
typename
Data
>
template
<
typename
Key
,
typename
Data
>
struct
TupleElementKeyData
struct
TupleElementKeyData
{
{
#if 0
#if 0
// workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
#else
__host__
__device__
constexpr
TupleElementKeyData
()
:
mData
{}
{}
__host__
__device__
constexpr
TupleElementKeyData
()
:
mData
{}
{}
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp
View file @
b58b98ff
#pragma once
#pragma once
#include <
stdlib.h
>
#include <
vector
>
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
View file @
b58b98ff
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -28,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -28,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// e = elementwise((a * b), d)
// e = elementwise((a * b), d)
// outout: e[m, n]
// outout: e[m, n]
// input: a[k, m], b[k, n], d[m, n]
// input: a[k, m], b[k, n], d[m, n]
using
device_gemm_add_add_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
using
device_gemm_add_add_
fast
gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
View file @
b58b98ff
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -28,28 +28,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -28,28 +28,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// e = elementwise((a * b), d)
// e = elementwise((a * b), d)
// outout: e[m, n]
// outout: e[m, n]
// input: a[k, m], b[n, k], d[m, n]
// input: a[k, m], b[n, k], d[m, n]
using
device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
=
std
::
tuple
<
using
device_gemm_
add_add_fast
gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B|
CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise|
Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | Operation| Operation|
Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//##############################| | | | | | | | | | | |
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAdd
FastGelu
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -57,8 +57,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
...
@@ -57,8 +57,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
std
::
vector
<
DeviceGemmMultipleDPtr
<
2
,
PassThrough
,
PassThrough
,
AddAddFastGelu
>>&
instances
)
std
::
vector
<
DeviceGemmMultipleDPtr
<
2
,
PassThrough
,
PassThrough
,
AddAddFastGelu
>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
{});
device_gemm_gelu_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
{});
}
}
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
b58b98ff
#include <stdlib.h>
#include <stdlib.h>
#include "config.hpp"
#include "config.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "device_operation_instance.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
b58b98ff
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
...
@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
profiler/CMakeLists.txt
View file @
b58b98ff
...
@@ -24,7 +24,7 @@ include_directories(BEFORE
...
@@ -24,7 +24,7 @@ include_directories(BEFORE
# ck_profiler
# ck_profiler
set
(
PROFILER_SOURCE
set
(
PROFILER_SOURCE
src/profiler.cpp
src/profiler.cpp
src/profile_gemm.cpp
#
src/profile_gemm.cpp
# src/profile_gemm_bias_2d.cpp
# src/profile_gemm_bias_2d.cpp
# src/profile_gemm_bias_relu.cpp
# src/profile_gemm_bias_relu.cpp
# src/profile_gemm_bias_relu_add.cpp
# src/profile_gemm_bias_relu_add.cpp
...
@@ -47,7 +47,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE})
...
@@ -47,7 +47,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE conv_util
)
target_link_libraries
(
ckProfiler PRIVATE conv_util
)
#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
...
...
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
View file @
b58b98ff
...
@@ -11,8 +11,8 @@
...
@@ -11,8 +11,8 @@
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "device_gemm_multiple_d.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -23,7 +23,7 @@ using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMult
...
@@ -23,7 +23,7 @@ using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMult
2
,
2
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
FastGelu
>
;
ck
::
tensor_operation
::
element_wise
::
AddAdd
FastGelu
>
;
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>&
);
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>&
);
...
@@ -44,6 +44,7 @@ namespace profiler {
...
@@ -44,6 +44,7 @@ namespace profiler {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
AccDataType
,
typename
D0DataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
D1DataType
,
typename
EDataType
,
typename
EDataType
,
...
@@ -54,7 +55,7 @@ template <typename ADataType,
...
@@ -54,7 +55,7 @@ template <typename ADataType,
typename
ELayout
>
typename
ELayout
>
int
profile_gemm_add_add_fastgelu_impl
(
int
do_verification
,
int
profile_gemm_add_add_fastgelu_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
/*
do_log
*/
,
bool
time_kernel
,
bool
time_kernel
,
int
M
,
int
M
,
int
N
,
int
N
,
...
@@ -131,28 +132,32 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
...
@@ -131,28 +132,32 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
device_op_ptrs
);
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
device_op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
device_op_ptrs
);
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
device_op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
device_op_ptrs
);
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
device_op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
device_op_ptrs
);
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
device_op_ptrs
);
}
}
}
}
...
...
profiler/src/profile_gemm_add_add_fastgelu.cpp
View file @
b58b98ff
...
@@ -22,16 +22,16 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
...
@@ -22,16 +22,16 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
enum
struct
MatrixDataType
enum
struct
MatrixDataType
{
{
F32_F32_F32_F32_F32
,
// 0
F32_F32_F32_F32_F32
,
// 0
F16_F16_F16_F16_F16
_F16_F16
,
// 1
F16_F16_F16_F16_F16
,
// 1
BF16_BF16_BF16_BF16_BF16
,
// 2
BF16_BF16_BF16_BF16_BF16
,
// 2
INT8_INT8_INT8_INT8_INT8
,
// 3
INT8_INT8_INT8_INT8_INT8
,
// 3
};
};
if
(
argc
!=
16
)
if
(
argc
!=
16
)
{
{
// clang-format off
// clang-format off
printf
(
"arg1: tensor operation (gemm_gelu: GEMM+Add+Add+GeLU)
\n
"
);
printf
(
"arg1: tensor operation (gemm_
add_add_fast
gelu: GEMM+Add+Add+GeLU)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)
\n
"
);
printf
(
"arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);
\n
"
);
printf
(
"arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);
\n
"
);
printf
(
" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);
\n
"
);
printf
(
" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);
\n
"
);
...
@@ -40,7 +40,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
...
@@ -40,7 +40,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: time kernel (0=n
0
, 1=yes)
\n
"
);
printf
(
"arg7: time kernel (0=n
o
, 1=yes)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE
\n
"
);
// clang-format on
// clang-format on
exit
(
1
);
exit
(
1
);
...
@@ -64,12 +64,14 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
...
@@ -64,12 +64,14 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
const
int
StrideE
=
std
::
stoi
(
argv
[
15
]);
const
int
StrideE
=
std
::
stoi
(
argv
[
15
]);
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
auto
profile
=
[
&
](
auto
a_type
,
auto
profile
=
[
&
](
auto
a_type
,
auto
b_type
,
auto
b_type
,
auto
acc_type
,
auto
d0_type
,
auto
d0_type
,
auto
d1_type
,
auto
d1_type
,
auto
e_type
,
auto
e_type
,
...
@@ -78,11 +80,12 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
...
@@ -78,11 +80,12 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
auto
d0_layout
,
auto
d0_layout
,
auto
d1_layout
,
auto
d1_layout
,
auto
e_layout
)
{
auto
e_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
BDataType
=
decltype
(
b_type
);
using
D0DataType
=
decltype
(
d0_type
);
using
AccDataType
=
decltype
(
acc_type
);
using
D1DataType
=
decltype
(
d1_type
);
using
D0DataType
=
decltype
(
d0_type
);
using
EDataType
=
decltype
(
e_type
);
using
D1DataType
=
decltype
(
d1_type
);
using
EDataType
=
decltype
(
e_type
);
using
ALayout
=
decltype
(
a_layout
);
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
BLayout
=
decltype
(
b_layout
);
...
@@ -96,16 +99,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
...
@@ -96,16 +99,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D1Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D1Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideE
=
ck
::
is_same_v
<
ELayout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideE
=
ck
::
is_same_v
<
ELayout
,
Row
>
?
N
:
M
;
return
ck
::
profiler
::
profile_gemm_add_add_gelu_impl
<
ADataType
,
return
ck
::
profiler
::
profile_gemm_add_add_fastgelu_impl
<
ADataType
,
BDataType
,
BDataType
,
D0DataType
,
AccDataType
,
D1DataType
,
D0DataType
,
EDataType
,
D1DataType
,
ALayout
,
EDataType
,
BLayout
,
ALayout
,
D0Layout
,
BLayout
,
D1Layout
,
D0Layout
,
ELayout
>
(
D1Layout
,
ELayout
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
...
@@ -122,22 +126,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
...
@@ -122,22 +126,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_KN_MN_MN_MN
)
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_KN_MN_MN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F16
{},
F16
{},
F16
{},
Row
{},
Row
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Row
{},
Row
{},
Row
{},
Row
{});
}
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_NK_MN_MN_MN
)
layout
==
MatrixLayout
::
MK_NK_MN_MN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{},
Row
{},
Row
{});
}
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
KM_KN_MN_MN_MN
)
layout
==
MatrixLayout
::
KM_KN_MN_MN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F16
{},
F16
{},
F16
{},
Col
{},
Row
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Row
{},
Row
{},
Row
{},
Row
{});
}
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
KM_NK_MN_MN_MN
)
layout
==
MatrixLayout
::
KM_NK_MN_MN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{},
Row
{},
Row
{});
}
}
else
else
{
{
...
...
profiler/src/profiler.cpp
View file @
b58b98ff
...
@@ -54,11 +54,11 @@ int main(int argc, char* argv[])
...
@@ -54,11 +54,11 @@ int main(int argc, char* argv[])
return
0
;
return
0
;
}
}
#if 0
if(strcmp(argv[1], "gemm") == 0)
if(strcmp(argv[1], "gemm") == 0)
{
{
return profile_gemm(argc, argv);
return profile_gemm(argc, argv);
}
}
#if 0
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
{
return profile_gemm_bias_2d(argc, argv);
return profile_gemm_bias_2d(argc, argv);
...
@@ -124,7 +124,7 @@ int main(int argc, char* argv[])
...
@@ -124,7 +124,7 @@ int main(int argc, char* argv[])
return profile_conv_bwd_weight(argc, argv);
return profile_conv_bwd_weight(argc, argv);
}
}
#endif
#endif
else
if
(
strcmp
(
argv
[
1
],
"gemm_add_add_fastgelu"
)
==
0
)
if
(
strcmp
(
argv
[
1
],
"gemm_add_add_fastgelu"
)
==
0
)
{
{
return
profile_gemm_add_add_fastgelu
(
argc
,
argv
);
return
profile_gemm_add_add_fastgelu
(
argc
,
argv
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment