Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1ddc3ec7
Unverified
Commit
1ddc3ec7
authored
Aug 31, 2023
by
Rostyslav Geyyer
Committed by
GitHub
Aug 31, 2023
Browse files
Merge branch 'develop' into lwpck-756
parents
e9703d5b
f5ec04f0
Changes
66
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1304 additions
and
6 deletions
+1304
-6
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
..._gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
+83
-0
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
..._gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
+83
-0
library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt
...tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt
+11
-0
library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_bf16_instance.cpp
...ce/gpu/max_pool_bwd/device_max_pool_bwd_bf16_instance.cpp
+20
-0
library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f16_instance.cpp
...nce/gpu/max_pool_bwd/device_max_pool_bwd_f16_instance.cpp
+20
-0
library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f32_instance.cpp
...nce/gpu/max_pool_bwd/device_max_pool_bwd_f32_instance.cpp
+20
-0
library/src/tensor_operation_instance/gpu/max_pool_bwd/max_pool_bwd_instance_common.hpp
...nstance/gpu/max_pool_bwd/max_pool_bwd_instance_common.hpp
+35
-0
library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt
...c/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt
+4
-0
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
.../pool3d_fwd/device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_bf16_instance.cpp
.../pool3d_fwd/device_max_pool3d_fwd_ndhwc_bf16_instance.cpp
+34
-0
library/src/tensor_operation_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp
...tion_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp
+1
-0
library/src/utility/device_memory.cpp
library/src/utility/device_memory.cpp
+10
-0
profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp
profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp
+253
-0
profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp
profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp
+288
-0
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+4
-0
profiler/src/profile_avg_pool3d_bwd.cpp
profiler/src/profile_avg_pool3d_bwd.cpp
+175
-0
profiler/src/profile_max_pool3d_bwd.cpp
profiler/src/profile_max_pool3d_bwd.cpp
+177
-0
profiler/src/profile_max_pool3d_fwd.cpp
profiler/src/profile_max_pool3d_fwd.cpp
+60
-4
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+0
-1
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
No files found.
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
D0DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
>
;
using
D0Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
16
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
32
,
256
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
16
,
128
,
32
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
// clang-format on
>
;
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_irregular_tile_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
D0DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
>
;
using
D0Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
64
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
256
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
128
,
32
,
256
,
64
,
8
,
8
,
32
,
32
,
1
,
4
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGroupedGemm_Xdl_Fixed_NK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
// clang-format on
>
;
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Col
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F32
,
PassThrough
,
PassThrough
,
Add
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_irregular_tile_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt
0 → 100644
View file @
1ddc3ec7
set
(
DEVICE_MAXPOOL_BWD_INSTANCES
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp
)
endif
()
add_instance_library
(
device_max_pool_bwd_instance
${
DEVICE_MAXPOOL_BWD_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_bf16_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "max_pool_bwd_instance_common.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_maxpool_bwd_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceMaxPoolBwd
<
BF16
,
I32
,
BF16
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_maxpool_bwd_instances
<
BF16
,
I32
,
BF16
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f16_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "max_pool_bwd_instance_common.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_maxpool_bwd_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceMaxPoolBwd
<
F16
,
I32
,
F16
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_maxpool_bwd_instances
<
F16
,
I32
,
F16
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f32_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "max_pool_bwd_instance_common.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_maxpool_bwd_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceMaxPoolBwd
<
F32
,
I32
,
F32
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_maxpool_bwd_instances
<
F32
,
I32
,
F32
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/max_pool_bwd/max_pool_bwd_instance_common.hpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
I32
=
int32_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
>
using
device_maxpool_bwd_instances
=
// clang-format off
std
::
tuple
<
DeviceMaxPoolBwdImpl
<
DOutDataType
,
IndexDataType
,
DInDataType
,
1
>
,
DeviceMaxPoolBwdImpl
<
DOutDataType
,
IndexDataType
,
DInDataType
,
2
>
,
DeviceMaxPoolBwdImpl
<
DOutDataType
,
IndexDataType
,
DInDataType
,
4
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt
View file @
1ddc3ec7
...
@@ -3,6 +3,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
...
@@ -3,6 +3,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
)
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
)
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
)
...
...
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "pool_fwd_instance_common.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
void
add_device_pool3d_fwd_ndhwc_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
BF16
,
BF16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
BF16
,
BF16
,
I32
,
F32
,
ReduceOpId
,
false
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_bf16_instance.cpp
0 → 100644
View file @
1ddc3ec7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "pool_fwd_instance_common.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
void
add_device_pool3d_fwd_ndhwc_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
BF16
,
BF16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
BF16
,
BF16
,
I32
,
BF16
,
ReduceOpId
,
false
>
{});
}
void
add_device_pool3d_fwd_ndhwc_index_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
BF16
,
BF16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
true
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
BF16
,
BF16
,
I32
,
BF16
,
ReduceOpId
,
true
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp
View file @
1ddc3ec7
...
@@ -17,6 +17,7 @@ namespace instance {
...
@@ -17,6 +17,7 @@ namespace instance {
using
I32
=
int32_t
;
using
I32
=
int32_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
F32
=
float
;
using
NDHWC
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
using
NDHWC
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
...
...
library/src/utility/device_memory.cpp
View file @
1ddc3ec7
...
@@ -37,6 +37,11 @@ void DeviceMem::ToDevice(const void* p) const
...
@@ -37,6 +37,11 @@ void DeviceMem::ToDevice(const void* p) const
}
}
}
}
void
DeviceMem
::
ToDevice
(
const
void
*
p
,
const
std
::
size_t
cpySize
)
const
{
hip_check_error
(
hipMemcpy
(
mpDeviceBuf
,
const_cast
<
void
*>
(
p
),
cpySize
,
hipMemcpyHostToDevice
));
}
void
DeviceMem
::
FromDevice
(
void
*
p
)
const
void
DeviceMem
::
FromDevice
(
void
*
p
)
const
{
{
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
...
@@ -49,6 +54,11 @@ void DeviceMem::FromDevice(void* p) const
...
@@ -49,6 +54,11 @@ void DeviceMem::FromDevice(void* p) const
}
}
}
}
void
DeviceMem
::
FromDevice
(
void
*
p
,
const
std
::
size_t
cpySize
)
const
{
hip_check_error
(
hipMemcpy
(
p
,
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
}
void
DeviceMem
::
SetZero
()
const
void
DeviceMem
::
SetZero
()
const
{
{
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
...
...
profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp
0 → 100644
View file @
1ddc3ec7
This diff is collapsed.
Click to expand it.
profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp
0 → 100644
View file @
1ddc3ec7
This diff is collapsed.
Click to expand it.
profiler/src/CMakeLists.txt
View file @
1ddc3ec7
...
@@ -19,6 +19,8 @@ set(PROFILER_SOURCES
...
@@ -19,6 +19,8 @@ set(PROFILER_SOURCES
profile_groupnorm.cpp
profile_groupnorm.cpp
profile_layernorm.cpp
profile_layernorm.cpp
profile_max_pool3d_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
profile_softmax.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_bwd.cpp
...
@@ -76,6 +78,8 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
...
@@ -76,6 +78,8 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_avg_pool3d_bwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_max_pool_bwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
if
(
DL_KERNELS
)
if
(
DL_KERNELS
)
...
...
profiler/src/profile_avg_pool3d_bwd.cpp
0 → 100644
View file @
1ddc3ec7
This diff is collapsed.
Click to expand it.
profiler/src/profile_max_pool3d_bwd.cpp
0 → 100644
View file @
1ddc3ec7
This diff is collapsed.
Click to expand it.
profiler/src/profile_max_pool3d_fwd.cpp
View file @
1ddc3ec7
This diff is collapsed.
Click to expand it.
script/cmake-ck-dev.sh
View file @
1ddc3ec7
...
@@ -16,4 +16,3 @@ cmake
...
@@ -16,4 +16,3 @@ cmake
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
test/CMakeLists.txt
View file @
1ddc3ec7
...
@@ -57,7 +57,7 @@ add_subdirectory(data_type)
...
@@ -57,7 +57,7 @@ add_subdirectory(data_type)
add_subdirectory
(
elementwise_normalization
)
add_subdirectory
(
elementwise_normalization
)
add_subdirectory
(
batchnorm
)
add_subdirectory
(
batchnorm
)
add_subdirectory
(
contraction
)
add_subdirectory
(
contraction
)
add_subdirectory
(
pool
_fwd
)
add_subdirectory
(
pool
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
grouped_convnd_bwd_data
)
add_subdirectory
(
grouped_convnd_bwd_data
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment