Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6252d207
"official/projects/pruning/train.py" did not exist on "b6907e8dcde1c42205a424fca10d8ee61586e23c"
Commit
6252d207
authored
Sep 13, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/fav3_fwd_sept
parents
eed60199
e07f1108
Changes
51
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1863 additions
and
142 deletions
+1863
-142
Dockerfile
Dockerfile
+2
-0
Jenkinsfile
Jenkinsfile
+2
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+8
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
...nsor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
+453
-0
include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp
...ration/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp
+523
-0
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+349
-74
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+28
-1
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
+38
-14
include/ck/utility/amd_smfmac.hpp
include/ck/utility/amd_smfmac.hpp
+14
-12
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+165
-10
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+15
-15
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...lock_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+6
-6
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp
.../library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp
+80
-0
library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp
...ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp
+12
-1
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
...e/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
+153
-0
library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt
...nsor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt
+8
-0
No files found.
Dockerfile
View file @
6252d207
...
...
@@ -130,6 +130,8 @@ ENV compiler_commit=$compiler_commit
RUN
sh
-c
"echo compiler version = '
$compiler_version
'"
RUN
sh
-c
"echo compiler commit = '
$compiler_commit
'"
ARG
DISABLE_CACHE=0
RUN if
(
[
"
$compiler_version
"
=
"amd-staging"
]
||
[
"
$compiler_version
"
=
"amd-mainline-open"
]
)
&&
[
"
$compiler_commit
"
=
""
]
;
then
\
git clone
-b
"
$compiler_version
"
https://github.com/ROCm/llvm-project.git
&&
\
cd
llvm-project
&&
mkdir
build
&&
cd
build
&&
\
...
...
Jenkinsfile
View file @
6252d207
...
...
@@ -94,7 +94,7 @@ def getDockerImage(Map conf=[:]){
env
.
DOCKER_BUILDKIT
=
1
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
def
no_cache
=
conf
.
get
(
"no_cache"
,
false
)
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}'
--build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}'
"
if
(
no_cache
)
{
dockerArgs
=
dockerArgs
+
" --no-cache "
...
...
@@ -124,7 +124,7 @@ def buildDocker(install_prefix){
checkout
scm
def
image_name
=
getDockerImageName
()
echo
"Building Docker for ${image_name}"
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}'
--build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}'
"
echo
"Build Args: ${dockerArgs}"
try
{
...
...
example/01_gemm/run_gemm_example.inc
View file @
6252d207
...
...
@@ -305,6 +305,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
#endif
}
else
{
// When the Problem Type and Problem Size does not fit.
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
": the instance does not support the problem config."
<<
std
::
endl
;
return
true
;
}
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
0 → 100644
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
TileDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
(
const
TileDesc_K0_MN_K1
&
)
{
constexpr
index_t
K0
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
0
>
{});
constexpr
index_t
K1
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
2
>
{});
return
transform_tensor_descriptor
(
TileDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
typename
ComputeTypeA
=
FloatA
,
typename
ComputeTypeB
=
FloatB
>
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
KPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
*
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
SparseXdlopsGemm
<
ComputeTypeA
,
MPerXDL
,
NPerXDL
,
KPack
,
ComputeTypeB
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPerThread
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
Number
<
m0
>
{},
Number
<
n0
>
{},
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
__host__
__device__
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
,
"MPerBlock must be divisible by MPerXDL * MRepeat"
);
static_assert
(
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"NPerBlock must be divisible by NPerXDL * NRepeat"
);
static_assert
(
KPack
%
(
16
*
sizeof
(
ComputeTypeA
))
==
0
,
"KPack must be divisbile by number of elements processed in single smfmac instruction"
);
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_M0_M1_M2_K
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
A_K0
>
{},
Number
<
A_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_N0_N1_N2_K
()
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
B_K0
>
{},
Number
<
B_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
static
constexpr
auto
b_block_desc_n0_n1_n2_k
=
MakeBBlockDescriptor_N0_N1_N2_K
();
// Prepares data in a_thread_buf by squeezing values by ommiting zeros to adjust it to 2:4
// structural sparsity. The indexes of non-zero elements are stored in idx_buf and used later in
// smfmac instruction
template
<
typename
AThreadBuf
,
typename
IdxBuf
,
int32_t
num_elems
>
__device__
void
SetIdxSqueezeA
(
AThreadBuf
&
a_thread_buf
,
IdxBuf
&
idx_buf
)
{
static
constexpr
int32_t
bit_clear_masks
[
4
]
=
{
0b11
,
0b1100
,
0b110000
,
0b11000000
};
static
constexpr
int32_t
processed_elems
=
16
/
sizeof
(
ComputeTypeA
);
static_for
<
0
,
num_elems
,
processed_elems
>
{}([
&
](
auto
i
)
{
constexpr
int
idx_reg_num
=
i
/
(
16
*
sizeof
(
ComputeTypeA
));
constexpr
int
idx_reg_part
=
(
i
%
32
)
/
processed_elems
;
vector_type
<
ComputeTypeA
,
processed_elems
>
a_thread_vec
;
static_for
<
0
,
processed_elems
,
1
>
{}([
&
](
auto
j
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
j
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
i
+
j
))
>
{}];
});
uint8_t
idx
=
0b11101110
;
// set to last 2 elems for both 4-elems subgroups by default
for
(
int
j
=
0
;
j
<
processed_elems
;
j
+=
4
)
{
int32_t
a_pos
=
idx_reg_part
*
processed_elems
+
j
;
int32_t
nonzero_pos
=
0
;
ComputeTypeA
nonzero_elems
[
2
]
=
{
a_thread_vec
[
j
+
2
],
a_thread_vec
[
j
+
3
]};
for
(
int
k
=
0
;
k
<
3
;
k
+=
1
)
{
if
(
a_thread_vec
[
j
+
k
]
!=
0.0
f
)
{
nonzero_elems
[
nonzero_pos
]
=
a_thread_vec
[
j
+
k
];
idx
&=
~
bit_clear_masks
[
j
/
2
+
nonzero_pos
];
idx
|=
k
<<
2
*
(
j
/
2
+
nonzero_pos
);
++
nonzero_pos
;
}
}
a_thread_vec
[
j
/
2
]
=
nonzero_elems
[
0
];
a_thread_vec
[
j
/
2
+
1
]
=
nonzero_elems
[
1
];
}
IdxBuf
[
idx_reg_num
].
AsType
<
int8x4_t
>
()[
Number
<
idx_reg_part
>
{}]
=
idx
;
static_for
<
0
,
processed_elems
/
2
,
1
>
{}([
&
](
auto
j
)
{
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
i
/
2
+
j
))
>
{}]
=
a_thread_vec
[
j
];
});
});
}
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeTypeA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeTypeB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static
constexpr
int32_t
elems_per_idx
=
16
*
sizeof
(
ComputeTypeA
);
auto
idx_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
>
(
(
a_thread_desc_
.
GetElementSpaceSize
()
+
elems_per_idx
-
1
)
/
elems_per_idx
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
SetIdxSqueezeA
(
a_thread_buf
,
idx_buf
,
a_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
// a_thread_vec is smaller because it's structurally sparse 2:4
vector_type
<
ComputeTypeA
,
KPack
/
2
>
a_thread_vec
;
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
vector_type
<
int32_t
,
KPack
/
elems_per_idx
>
idx_vec
;
static_for
<
0
,
KPack
/
2
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
/
2
+
i
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
ComputeTypeB
>()(
2
*
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
static_for
<
0
,
KPack
/
elems_per_idx
,
1
>
{}([
&
](
auto
i
)
{
idx_vec
.
template
AsType
<
int32_t
>()(
i
)
=
idx_buf
[
k
/
elems_per_idx
+
i
];
});
// A is smaller because it's structurally sparse 2:4
using
mfma_input_type_a
=
typename
vector_type
<
ComputeTypeA
,
xdlops_gemm
.
K1PerXdlops
/
2
>::
type
;
using
mfma_input_type_b
=
typename
vector_type
<
ComputeTypeB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_idx
=
typename
vector_type
<
int32_t
,
1
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
idx_vec
.
template
AsType
<
mfma_input_type_idx
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// B[N0, N1, N2, KPerThread]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
ComputeTypeA
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
ComputeTypeB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp
0 → 100644
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// In and Din = [N, C, Hi, Wi]
// Out and Dout = [N, C, Ho, Wo]
// Out = AvgPool2dFwd(In)
// Din = AvgPool2dBwd(Dout)
// Pooling dimension = H, W
template
<
typename
DOutDataType
,
typename
DInDataType
,
typename
ComputeDataType
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MThreadClusterSize
,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DeviceAvgPool2dBwd_NHWC_NHWC
:
public
DeviceAvgPoolBwd
<
2
,
DOutDataType
,
DInDataType
,
tensor_layout
::
convolution
::
NHWC
,
tensor_layout
::
convolution
::
NHWC
>
{
static
constexpr
ck
::
index_t
NDimSpatial
=
2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
Make2DGridDescriptor_Out_M_K_In_M
(
const
std
::
vector
<
ck
::
index_t
>&
dout_n_c_wos_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_c_wos_length
,
const
std
::
vector
<
ck
::
index_t
>&
dout_n_c_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_c_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
window_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
,
const
std
::
vector
<
ck
::
index_t
>&
tildes
)
{
index_t
i_ytilde
=
tildes
[
0
];
index_t
i_xtilde
=
tildes
[
1
];
const
index_t
N
=
dout_n_c_wos_lengths
[
0
];
const
index_t
C
=
dout_n_c_wos_lengths
[
1
];
const
index_t
Ho
=
dout_n_c_wos_lengths
[
2
];
const
index_t
Wo
=
dout_n_c_wos_lengths
[
3
];
const
index_t
Hi
=
din_n_c_wos_length
[
2
];
const
index_t
Wi
=
din_n_c_wos_length
[
3
];
const
index_t
Y
=
window_lengths
[
0
];
const
index_t
X
=
window_lengths
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
ConvStrideH
=
window_strides
[
0
];
const
index_t
ConvStrideW
=
window_strides
[
1
];
const
index_t
ConvDilationH
=
window_dilations
[
0
];
const
index_t
ConvDilationW
=
window_dilations
[
1
];
const
index_t
Ni_stride
=
dout_n_c_wos_strides
[
0
];
const
index_t
Ci_stride
=
dout_n_c_wos_strides
[
1
];
const
index_t
Ho_stride
=
dout_n_c_wos_strides
[
2
];
const
index_t
Wo_stride
=
dout_n_c_wos_strides
[
3
];
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on Tildes that contribute to non-padding area of input tensor
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// ReduceK is different for each Reduce
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// Problem size of reduction kernel
const
index_t
MRaw
=
N
*
HTildeSlice
*
WTildeSlice
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
YDotSlice
*
XDotSlice
;
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
const
auto
out_n_ho_wo_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Ho
,
Wo
,
C
),
make_tuple
(
Ni_stride
,
Ho_stride
,
Wo_stride
,
Ci_stride
));
// Out[ReduceM, ReduceK]
const
auto
out_n_hop_wop_c_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilde_xdot_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
out_n_ydot_htilde_xdot_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
,
C
)),
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_grid_desc_reducem_reducek
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw_reducekraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// In[ReduceM]
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
din_n_c_wos_strides
[
0
],
din_n_c_wos_strides
[
2
],
din_n_c_wos_strides
[
3
],
din_n_c_wos_strides
[
1
]));
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ytilde
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_grid_desc_reducemraw
=
transform_tensor_descriptor
(
in_n_htildeslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
in_grid_desc_reducem
=
transform_tensor_descriptor
(
in_grid_desc_reducemraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
make_tuple
(
out_grid_desc_reducem_reducek
,
in_grid_desc_reducem
);
}
using
DoutDinGridDesc
=
decltype
(
Make2DGridDescriptor_Out_M_K_In_M
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
},
{
0
,
0
},
{
0
,
0
},
{
0
,
0
},
{
0
,
0
},
{
0
,
0
},
{
0
,
0
}));
using
DoutGridDesc_M_K
=
remove_cvref_t
<
tuple_element_t
<
0
,
DoutDinGridDesc
>>
;
using
DinGridDesc_M
=
remove_cvref_t
<
tuple_element_t
<
1
,
DoutDinGridDesc
>>
;
// FIXME
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static
constexpr
index_t
OutSrcInDstVectorDim
=
0
;
// 0: M, 1: K
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
Div
=
tensor_operation
::
element_wise
::
UnaryDivide
;
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
DOutDataType
,
DInDataType
,
ComputeDataType
,
int
,
DoutGridDesc_M_K
,
DinGridDesc_M
,
reduce
::
Add
,
PassThrough
,
Div
,
InMemoryDataOperationEnum
::
Set
,
false
,
// propagate_nan
BlockSize
,
MThreadSliceSize
,
KThreadSliceSize
,
OutSrcInDstVectorDim
,
InSrcOutDstVectorSize
,
InSrcOutDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DOutDataType
*
p_dout
,
DInDataType
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
:
p_dout_grid_
{
p_dout
},
p_din_grid_
{
p_din
},
dout_n_c_wos_lengths_
{
dout_n_c_wos_lengths
},
din_n_c_wos_length_
{
din_n_c_wos_length
},
dout_n_c_wos_strides_
{
dout_n_c_wos_strides
},
din_n_c_wos_strides_
{
din_n_c_wos_strides
},
num_reduce_
{
1
},
div_element_op_
{
window_lengths
[
0
]
*
window_lengths
[
1
]}
{
std
::
vector
<
ck
::
index_t
>
Tildes
(
NDimSpatial
);
for
(
int
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
int
GcdStrideDilation
=
math
::
gcd
(
window_strides
[
i
],
window_dilations
[
i
]);
Tildes
[
i
]
=
window_strides
[
i
]
/
GcdStrideDilation
;
num_reduce_
*=
Tildes
[
i
];
}
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
Tildes
[
0
];
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
Tildes
[
1
];
++
i_xtilde
)
{
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
window_lengths
[
0
]
-
i_ytilde
,
Tildes
[
0
]);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
window_lengths
[
1
]
-
i_xtilde
,
Tildes
[
1
]);
if
(
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
const
auto
dout_din_grid_desc
=
Make2DGridDescriptor_Out_M_K_In_M
(
dout_n_c_wos_lengths
,
din_n_c_wos_length
,
dout_n_c_wos_strides
,
din_n_c_wos_strides
,
window_lengths
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
dout_grid_desc_m_k_container_
.
push_back
(
dout_din_grid_desc
[
I0
]);
din_grid_desc_m_container_
.
push_back
(
dout_din_grid_desc
[
I1
]);
}
}
}
const
DOutDataType
*
p_dout_grid_
;
DInDataType
*
p_din_grid_
;
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths_
;
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length_
;
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides_
;
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides_
;
int
num_reduce_
;
std
::
vector
<
DoutGridDesc_M_K
>
dout_grid_desc_m_k_container_
;
std
::
vector
<
DinGridDesc_M
>
din_grid_desc_m_container_
;
Div
div_element_op_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
ave_time
=
0
;
for
(
index_t
i
=
0
;
i
<
arg
.
num_reduce_
;
i
++
)
{
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
false
,
false
,
false
,
// don't have index input
DOutDataType
,
DInDataType
,
ComputeDataType
,
int
,
DoutGridDesc_M_K
,
DinGridDesc_M
,
PassThrough
,
Div
>
;
ck
::
index_t
M
=
arg
.
dout_grid_desc_m_k_container_
[
i
].
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
dout_grid_desc_m_k_container_
[
i
],
arg
.
din_grid_desc_m_container_
[
i
],
PassThrough
{},
arg
.
div_element_op_
,
float
(
1
),
arg
.
p_dout_grid_
,
nullptr
,
float
(
0
),
arg
.
p_din_grid_
,
nullptr
);
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
constexpr
index_t
Rank
=
NDimSpatial
+
2
;
int
doutFastestDim
=
-
1
;
int
dinFastestDim
=
-
1
;
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
if
(
arg
.
dout_n_c_wos_strides_
[
i
]
==
1
)
doutFastestDim
=
i
;
if
(
arg
.
din_n_c_wos_strides_
[
i
]
==
1
)
dinFastestDim
=
i
;
}
if
(
InSrcOutDstVectorSize
!=
1
&&
(
dinFastestDim
!=
1
||
doutFastestDim
!=
1
))
{
return
false
;
}
if
(
doutFastestDim
==
-
1
||
dinFastestDim
==
-
1
)
{
if
constexpr
(
InSrcOutDstVectorSize
!=
1
)
return
false
;
}
else
{
if
(
arg
.
dout_n_c_wos_lengths_
[
doutFastestDim
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
if
(
arg
.
din_n_c_wos_length_
[
dinFastestDim
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
}
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
void
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
override
{
constexpr
index_t
Rank
=
NDimSpatial
+
2
;
if
(
dout_n_c_wos_strides
.
size
()
!=
Rank
||
din_n_c_wos_strides
.
size
()
!=
Rank
||
dout_n_c_wos_lengths
.
size
()
!=
Rank
||
din_n_c_wos_length
.
size
()
!=
Rank
)
{
throw
std
::
runtime_error
(
"dimension of [dout|din]_n_c_wos_strides or "
"[dout|din]_n_c_wos_lengths is not equal to Rank"
);
}
if
(
window_lengths
.
size
()
!=
NDimSpatial
||
window_strides
.
size
()
!=
NDimSpatial
||
window_dilations
.
size
()
!=
NDimSpatial
||
input_left_pads
.
size
()
!=
NDimSpatial
||
input_right_pads
.
size
()
!=
NDimSpatial
)
{
throw
std
::
runtime_error
(
"dimension of [window_lengths, window_strides, window_dilations, input_left_pads, "
"input_right_pads] is not equal to Rank"
);
}
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
static_cast
<
DInDataType
*>
(
p_din
),
dout_n_c_wos_lengths
,
din_n_c_wos_length
,
dout_n_c_wos_strides
,
din_n_c_wos_strides
,
window_lengths
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceAvgPool2dBwd<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -16,95 +27,359 @@ template <typename InDataType,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
,
ck
::
index_t
BlockSize
,
ck
::
index_t
Reduce
MThreadClusterSize
,
ck
::
index_t
Reduce
KThreadClusterSize
,
ck
::
index_t
Reduce
MThreadSliceSize
,
ck
::
index_t
Reduce
KThreadSliceSize
,
ck
::
index_t
MThreadClusterSize
,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_NHWC_NHWC
:
public
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
BlockSize
,
ReduceMThreadClusterSize
,
ReduceKThreadClusterSize
,
ReduceMThreadSliceSize
,
ReduceKThreadSliceSize
,
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_NHWC_NHWC
:
public
DevicePoolFwd
<
4
,
2
,
InDataType
,
OutDataType
,
IndexDataType
,
tensor_layout
::
convolution
::
NHWC
,
tensor_layout
::
convolution
::
NHWC
,
ReduceOpId
,
OutputIndex
>
{
using
DevicePool3D
=
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
std
::
vector
<
ck
::
index_t
>
input_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>
output_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>
input_nchw_stride
,
std
::
vector
<
ck
::
index_t
>
output_nchw_stride
,
std
::
vector
<
ck
::
index_t
>
window_spatial_yx_lengths
,
std
::
vector
<
ck
::
index_t
>
window_yx_strides
,
std
::
vector
<
ck
::
index_t
>
window_yx_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_hw_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_hw_pads
)
{
const
index_t
N
=
input_nchw_lengths
[
0
];
const
index_t
C
=
input_nchw_lengths
[
1
];
const
index_t
Hi
=
input_nchw_lengths
[
2
];
const
index_t
Wi
=
input_nchw_lengths
[
3
];
const
index_t
Ho
=
output_nchw_lengths
[
2
];
const
index_t
Wo
=
output_nchw_lengths
[
3
];
const
index_t
Y
=
window_spatial_yx_lengths
[
0
];
const
index_t
X
=
window_spatial_yx_lengths
[
1
];
const
index_t
WindowStrideH
=
window_yx_strides
[
0
];
const
index_t
WindowStrideW
=
window_yx_strides
[
1
];
const
index_t
WindowDilationH
=
window_yx_dilations
[
0
];
const
index_t
WindowDilationW
=
window_yx_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_hw_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_hw_pads
[
1
];
const
index_t
InRightPadH
=
input_right_hw_pads
[
0
];
const
index_t
InRightPadW
=
input_right_hw_pads
[
1
];
const
index_t
MRaw
=
N
*
Ho
*
Wo
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
Y
*
X
;
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
// A[ReduceM, ReduceK]
const
index_t
Ni_stride
=
input_nchw_stride
[
0
];
const
index_t
Ci_stride
=
input_nchw_stride
[
1
];
const
index_t
Hi_stride
=
input_nchw_stride
[
2
];
const
index_t
Wi_stride
=
input_nchw_stride
[
3
];
const
auto
in_grid_desc_n_hi_wi_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
Ni_stride
,
Hi_stride
,
Wi_stride
,
Ci_stride
));
const
auto
in_grid_desc_n_hip_wip_c
=
transform_tensor_descriptor
(
in_grid_desc_n_hi_wi_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_grid_desc_n_y_ho_x_wo_c
=
transform_tensor_descriptor
(
in_grid_desc_n_hip_wip_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
WindowDilationH
,
WindowStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
WindowDilationW
,
WindowStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
in_grid_desc_n_y_ho_x_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
,
C
)),
make_merge_transform
(
make_tuple
(
Y
,
X
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_grid_desc_reducem_reducek
=
transform_tensor_descriptor
(
in_grid_desc_reducemraw_reducekraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B[ReduceM]
const
index_t
No_stride
=
output_nchw_stride
[
0
];
const
index_t
Co_stride
=
output_nchw_stride
[
1
];
const
index_t
Ho_stride
=
output_nchw_stride
[
2
];
const
index_t
Wo_stride
=
output_nchw_stride
[
3
];
const
auto
out_grid_desc_n_ho_wo_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
No_stride
,
Ho_stride
,
Wo_stride
,
Co_stride
));
const
auto
out_grid_desc_reducemraw
=
transform_tensor_descriptor
(
out_grid_desc_n_ho_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
out_grid_desc_reducem
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
({},
{},
{},
{},
{},
{},
{},
{},
{}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_dev
,
OutDataType
*
p_out_dev
,
IndexDataType
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>&
input_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>&
input_nchw_stride
,
std
::
vector
<
ck
::
index_t
>&
output_nchw_stride
,
std
::
vector
<
ck
::
index_t
>&
,
// indices_nchw_stride
std
::
vector
<
ck
::
index_t
>&
window_spatial_yx_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_yx_strides
,
std
::
vector
<
ck
::
index_t
>&
window_yx_dilations
,
std
::
vector
<
ck
::
index_t
>&
input_left_hw_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_hw_pads
)
:
p_in_dev_
{
p_in_dev
},
p_out_dev_
{
p_out_dev
},
p_out_indices_dev_
{
p_out_indices_dev
},
a_grid_desc_m_k_
{},
b_grid_desc_m_
{},
input_nchw_lengths_
{
input_nchw_lengths
},
output_nchw_lengths_
{
output_nchw_lengths
},
input_nchw_stride_
{
input_nchw_stride
},
output_nchw_stride_
{
output_nchw_stride
}
{
const
auto
descs
=
MakeABGridDescriptor_A_M_K_B_M
(
input_nchw_lengths
,
output_nchw_lengths
,
input_nchw_stride
,
output_nchw_stride
,
window_spatial_yx_lengths
,
window_yx_strides
,
window_yx_dilations
,
input_left_hw_pads
,
input_right_hw_pads
);
a_grid_desc_m_k_
=
descs
[
I0
];
b_grid_desc_m_
=
descs
[
I1
];
int32_t
reduceLength
=
window_spatial_yx_lengths
[
0
]
*
window_spatial_yx_lengths
[
1
];
std
::
tie
(
in_element_op_
,
acc_element_op_
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
reduceLength
);
}
const
InDataType
*
p_in_dev_
;
OutDataType
*
p_out_dev_
;
IndexDataType
*
p_out_indices_dev_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_M
b_grid_desc_m_
;
InElementwiseOperation
in_element_op_
;
AccElementwiseOperation
acc_element_op_
;
// for checking vector load/store
std
::
vector
<
ck
::
index_t
>
input_nchw_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_nchw_lengths_
;
std
::
vector
<
ck
::
index_t
>
input_nchw_stride_
;
std
::
vector
<
ck
::
index_t
>
output_nchw_stride_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static
constexpr
index_t
InSrcOutDstVectorDim
=
0
;
// 0: M, 1: K
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
false
,
// propagate_nan
BlockSize
,
Reduce
MThread
Cluster
Size
,
Reduce
KThread
Cluster
Size
,
ReduceMThreadSliceSize
,
ReduceKThreadSlice
Size
,
MThread
Slice
Size
,
KThread
Slice
Size
,
InSrcOutDstVectorDim
,
InSrcOutDstVector
Size
,
InSrcOutDstVectorSize
>
;
std
::
unique_ptr
<
BaseArgument
>
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
OutputIndex
,
true
,
// pooling need to return global index
false
,
// don't have index input
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
InElementwiseOperation
,
AccElementwiseOperation
>
;
ck
::
index_t
M
=
arg
.
a_grid_desc_m_k_
.
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_m_
,
arg
.
in_element_op_
,
arg
.
acc_element_op_
,
float
(
1
),
arg
.
p_in_dev_
,
nullptr
,
float
(
0
),
arg
.
p_out_dev_
,
arg
.
p_out_indices_dev_
);
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
// C should be fastest dimension
if
(
pArg
->
input_nchw_stride_
[
1
]
!=
1
)
return
false
;
for
(
int
i
=
0
;
i
<
InOutRank
;
++
i
)
{
if
(
pArg
->
input_nchw_stride_
[
i
]
==
1
&&
pArg
->
input_nchw_lengths_
[
i
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
if
(
pArg
->
output_nchw_stride_
[
i
]
==
1
&&
pArg
->
output_nchw_lengths_
[
i
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
}
return
true
;
}
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_
nchw_
lengths
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
lengths
,
std
::
vector
<
ck
::
index_t
>
output_
nchw_
lengths
,
std
::
vector
<
ck
::
index_t
>
input_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
output_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
indices_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
strides
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_
hw_
pads
,
std
::
vector
<
ck
::
index_t
>
input_right_
hw_
pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
override
{
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
if
(
input_lengths
.
size
()
!=
InOutRank
||
window_lengths
.
size
()
!=
WindowRank
||
input_lengths
.
size
()
!=
InOutRank
||
window_strides
.
size
()
!=
WindowRank
||
window_dilations
.
size
()
!=
WindowRank
||
input_left_pads
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
if
(
input_nchw_lengths
.
size
()
!=
InOutRank
||
window_yx_lengths
.
size
()
!=
WindowRank
||
input_nchw_lengths
.
size
()
!=
InOutRank
||
window_yx_strides
.
size
()
!=
WindowRank
||
window_yx_dilations
.
size
()
!=
WindowRank
||
input_left_hw_pads
.
size
()
!=
WindowRank
||
input_right_hw_pads
.
size
()
!=
WindowRank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
pooling_dims
!=
std
::
vector
<
ck
::
index_t
>
{
2
,
3
})
throw
std
::
runtime_error
(
"pooling_dims only support {2, 3} in pool2d so far"
);
// NCHW to NCDHW
input_lengths
.
insert
(
input_lengths
.
begin
()
+
2
,
1
);
output_lengths
.
insert
(
output_lengths
.
begin
()
+
2
,
1
);
input_stride
.
insert
(
input_stride
.
begin
()
+
2
,
0
);
output_stride
.
insert
(
output_stride
.
begin
()
+
2
,
0
);
indices_stride
.
insert
(
indices_stride
.
begin
()
+
2
,
0
);
// YX to ZYX
window_lengths
.
insert
(
window_lengths
.
begin
(),
1
);
window_strides
.
insert
(
window_strides
.
begin
(),
0
);
window_dilations
.
insert
(
window_dilations
.
begin
(),
0
);
input_left_pads
.
insert
(
input_left_pads
.
begin
(),
0
);
input_right_pads
.
insert
(
input_right_pads
.
begin
(),
0
);
pooling_dims
=
{
2
,
3
,
4
};
return
DevicePool3D
::
MakeArgumentPointer
(
p_in_dev
,
p_out_dev
,
p_out_indices_dev
,
input_lengths
,
window_lengths
,
output_lengths
,
input_stride
,
output_stride
,
indices_stride
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
pooling_dims
);
if
(
output_nchw_stride
!=
indices_nchw_stride
)
throw
std
::
runtime_error
(
"output_nchw_stride need to be equal to indices_nchw_stride for now"
);
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_dev
),
static_cast
<
OutDataType
*>
(
p_out_dev
),
static_cast
<
IndexDataType
*>
(
p_out_indices_dev
),
input_nchw_lengths
,
output_nchw_lengths
,
input_nchw_stride
,
output_nchw_stride
,
indices_nchw_stride
,
window_yx_lengths
,
window_yx_strides
,
window_yx_dilations
,
input_left_hw_pads
,
input_right_hw_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DevicePool2dFwd_NHWC_NHWC<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
6252d207
...
...
@@ -355,12 +355,39 @@ struct UnaryDivide
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
/
type_convert
<
T
>
(
divider_
);
};
template
<
>
__host__
__device__
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
half_t
>
(
x_
/
divider_f_
);
};
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
bhalf_t
>
(
x_
/
divider_f_
);
};
template
<
>
__host__
__device__
void
operator
()
<
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
f8_t
>
(
x_
/
divider_f_
);
};
int32_t
divider_
=
1
;
};
...
...
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
View file @
6252d207
...
...
@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!"
);
static_for
<
0
,
KPack
/
smfmac_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
],
p_c_thread
);
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
,
k
%
4
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
/
4
],
p_c_thread
);
});
}
...
...
include/ck/utility/amd_smfmac.hpp
View file @
6252d207
...
...
@@ -9,16 +9,18 @@ namespace ck {
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_16x16x32f16
;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template
<
>
struct
intrin_smfmac_f32_16x16x32f16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template
<
>
struct
intrin_smfmac_f32_16x16x32bf16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template
<
>
struct
intrin_smfmac_f32_32x32x16f16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template
<
>
struct
intrin_smfmac_f32_32x32x16bf16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
include/ck/utility/reduction_operator.hpp
View file @
6252d207
...
...
@@ -52,12 +52,28 @@ struct Add
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
f8_t
>
(
a_
+
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
half_t
>
(
a_
+
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -112,12 +128,28 @@ struct Mul
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
f8_t
>
(
a_
*
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
half_t
>
(
a_
*
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -137,6 +169,16 @@ struct Max
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
bhalf_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
f8_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
half_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Lowest
();
...
...
@@ -154,8 +196,7 @@ struct Max
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -171,12 +212,29 @@ struct Max
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -197,6 +255,30 @@ struct Max
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
struct
Min
...
...
@@ -209,6 +291,16 @@ struct Min
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
bhalf_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
half_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
f8_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Max
();
...
...
@@ -227,8 +319,7 @@ struct Min
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
...
...
@@ -244,6 +335,24 @@ struct Min
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -270,6 +379,30 @@ struct Min
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
struct
AMax
...
...
@@ -299,6 +432,15 @@ struct AMax
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -313,6 +455,18 @@ struct AMax
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
template
<
typename
T
>
...
...
@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
||
is_same
<
DataType
,
f8_t
>::
value
;
};
template
<
typename
DataType
>
...
...
@@ -361,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add,
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
is_same
<
DataType
,
int32_t
>::
value
||
is_same
<
DataType
,
f8_t
>::
value
;
};
}
// namespace reduce
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
6252d207
...
...
@@ -29,9 +29,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
...
...
@@ -62,9 +62,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
...
...
@@ -94,9 +94,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK2
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
...
...
@@ -127,9 +127,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK3
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
...
...
@@ -159,9 +159,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK4
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
6252d207
...
...
@@ -25,7 +25,7 @@ struct GemmKernel
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
Kernel
BlockSize
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
k
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
6252d207
...
...
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
static
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
6252d207
...
...
@@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
k
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
...
...
@@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
M0
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
k
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
...
...
@@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
N1
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
k
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
...
...
@@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
N0
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
k
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
6252d207
...
...
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
static
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @
6252d207
...
...
@@ -23,10 +23,10 @@ struct BlockGemmPipelineProblem
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
Kernel
BlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
index_t
k
BlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
VectorLoadSize
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
AlignmentB
=
kPadB
?
VectorLoadSize
/
sizeof
(
BDataType
)
:
1
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp
0 → 100644
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef CK_ENABLE_BF16
void
add_device_avgpool_2D_bwd_nhwc_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
BF16
,
BF16
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_FP16
void
add_device_avgpool_2D_bwd_nhwc_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
F16
,
F16
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_FP8
void
add_device_avgpool_2D_bwd_nhwc_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
F8
,
F8
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_avgpool_2D_bwd_nhwc_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
F32
,
F32
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_avgpool_2D_bwd_nhwc_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
I8
,
I8
,
NHWC
,
NHWC
>>>&
);
#endif
template
<
typename
DOutDataType
,
typename
DInDataType
,
typename
InLayout
,
typename
OutLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceAvgPoolBwd
<
2
,
DOutDataType
,
DInDataType
,
InLayout
,
OutLayout
>>
{
using
DeviceOp
=
DeviceAvgPoolBwd
<
2
,
DOutDataType
,
DInDataType
,
InLayout
,
OutLayout
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
InLayout
,
NHWC
>
&&
is_same_v
<
OutLayout
,
NHWC
>
)
{
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
DOutDataType
,
F16
>
&&
is_same_v
<
DInDataType
,
F16
>
)
add_device_avgpool_2D_bwd_nhwc_f16_instances
(
op_ptrs
);
#endif
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
DOutDataType
,
BF16
>
&&
is_same_v
<
DInDataType
,
BF16
>
)
add_device_avgpool_2D_bwd_nhwc_bf16_instances
(
op_ptrs
);
#endif
#ifdef CK_ENABLE_FP32
else
if
constexpr
(
is_same_v
<
DOutDataType
,
F32
>
&&
is_same_v
<
DInDataType
,
F32
>
)
add_device_avgpool_2D_bwd_nhwc_f32_instances
(
op_ptrs
);
#endif
#ifdef CK_ENABLE_FP8
else
if
constexpr
(
is_same_v
<
DOutDataType
,
F8
>
&&
is_same_v
<
DInDataType
,
F8
>
)
add_device_avgpool_2D_bwd_nhwc_f8_instances
(
op_ptrs
);
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
DOutDataType
,
I8
>
&&
is_same_v
<
DInDataType
,
I8
>
)
add_device_avgpool_2D_bwd_nhwc_int8_instances
(
op_ptrs
);
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -23,6 +23,11 @@ void add_device_maxpool_bwd_bf16_instances(
void
add_device_maxpool_bwd_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceMaxPoolBwd
<
F32
,
I32
,
F32
>>>&
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_maxpool_bwd_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceMaxPoolBwd
<
I8
,
I32
,
I8
>>>&
);
#endif
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceMaxPoolBwd
<
DOutDataType
,
IndexDataType
,
DInDataType
>>
...
...
@@ -32,6 +37,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
DOutDataType
,
F16
>
&&
is_same_v
<
DInDataType
,
F16
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
...
...
@@ -47,6 +53,11 @@ struct DeviceOperationInstanceFactory<
is_same_v
<
IndexDataType
,
I32
>
)
add_device_maxpool_bwd_f32_instances
(
op_ptrs
);
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
DOutDataType
,
I8
>
&&
is_same_v
<
DInDataType
,
I8
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
add_device_maxpool_bwd_int8_instances
(
op_ptrs
);
#endif
return
op_ptrs
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
0 → 100644
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
InOutRank
=
4
;
static
constexpr
auto
WindowRank
=
2
;
static
constexpr
auto
MaxOp
=
ck
::
ReduceTensorOp
::
MAX
;
static
constexpr
auto
AvgOp
=
ck
::
ReduceTensorOp
::
AVG
;
#ifdef CK_ENABLE_FP16
// FP16
void
add_device_pool2d_fwd_nhwc_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
NHWC
,
NHWC
,
MaxOp
,
false
>>>&
);
void
add_device_pool2d_fwd_nhwc_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
NHWC
,
NHWC
,
AvgOp
,
false
>>>&
);
// FP16 - return index
void
add_device_pool2d_fwd_nhwc_index_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
NHWC
,
NHWC
,
MaxOp
,
true
>>>&
);
#endif
#ifdef CK_ENABLE_BF16
// BF16
void
add_device_pool2d_fwd_nhwc_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
BF16
,
BF16
,
I32
,
NHWC
,
NHWC
,
MaxOp
,
false
>>>&
);
void
add_device_pool2d_fwd_nhwc_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
BF16
,
BF16
,
I32
,
NHWC
,
NHWC
,
AvgOp
,
false
>>>&
);
// BF16 - return index
void
add_device_pool2d_fwd_nhwc_index_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
BF16
,
BF16
,
I32
,
NHWC
,
NHWC
,
MaxOp
,
true
>>>&
);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void
add_device_pool2d_fwd_nhwc_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
NHWC
,
NHWC
,
MaxOp
,
false
>>>&
);
void
add_device_pool2d_fwd_nhwc_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
NHWC
,
NHWC
,
AvgOp
,
false
>>>&
);
// FP32 - return index
void
add_device_pool2d_fwd_nhwc_index_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
NHWC
,
NHWC
,
MaxOp
,
true
>>>&
);
#endif
template
<
typename
InDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
InLayout
,
typename
OutLayout
,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DevicePoolFwd
<
InOutRank
,
WindowRank
,
InDataType
,
OutDataType
,
IndexDataType
,
InLayout
,
OutLayout
,
ReduceOpId
,
OutputIndex
>>
{
using
DeviceOp
=
DevicePoolFwd
<
InOutRank
,
WindowRank
,
InDataType
,
OutDataType
,
IndexDataType
,
InLayout
,
OutLayout
,
ReduceOpId
,
OutputIndex
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
InLayout
,
NHWC
>
&&
is_same_v
<
OutLayout
,
NHWC
>
)
{
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
F16
>
&&
is_same_v
<
OutDataType
,
F16
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
{
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
{
add_device_pool2d_fwd_nhwc_index_f16_instances
(
op_ptrs
);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
BF16
>
&&
is_same_v
<
OutDataType
,
BF16
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
{
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
{
add_device_pool2d_fwd_nhwc_index_bf16_instances
(
op_ptrs
);
}
else
{
add_device_pool2d_fwd_nhwc_bf16_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_FP32
else
if
constexpr
(
is_same_v
<
InDataType
,
F32
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
{
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
{
add_device_pool2d_fwd_nhwc_index_f32_instances
(
op_ptrs
);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances
(
op_ptrs
);
}
}
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt
0 → 100644
View file @
6252d207
set
(
DEVICE_AVGPOOL_2D_BWD_INSTANCES
)
list
(
APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
device_avg_pool2d_bwd_nhwc_f16_instance.cpp
device_avg_pool2d_bwd_nhwc_f32_instance.cpp
device_avg_pool2d_bwd_nhwc_f8_instance.cpp
device_avg_pool2d_bwd_nhwc_int8_instance.cpp
)
add_instance_library
(
device_avg_pool2d_bwd_instance
${
DEVICE_AVGPOOL_2D_BWD_INSTANCES
}
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment