Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bd689f40
Commit
bd689f40
authored
Aug 20, 2024
by
illsilin
Browse files
merge from public repo
parents
c160c6cf
a94113a9
Changes
333
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3339 additions
and
589 deletions
+3339
-589
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+142
-162
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp
...eration/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp
+703
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+38
-38
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+194
-58
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
...device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+31
-50
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+210
-84
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
.../device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+188
-58
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
...e_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
+46
-55
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+202
-57
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
...grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+1054
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+16
-0
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
...operation/gpu/device/impl/device_image_to_column_impl.hpp
+14
-14
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
...tensor_operation/gpu/device/impl/device_reduce_common.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
...or_operation/gpu/device/impl/device_reduce_multiblock.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
...or_operation/gpu/device/impl/device_reduce_threadwise.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
...tion/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
+412
-0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+26
-0
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+26
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+28
-3
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
...on/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
+2
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
bd689f40
...
...
@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
rotating_mem
.
Next
();
// clear c mem
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
...
@@ -190,14 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
arg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
if
(
arg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
...
...
@@ -215,15 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
...
...
@@ -240,118 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
valu
e
)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
On
e
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
TailNumber
::
Two
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
...
...
@@ -473,28 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
...
...
@@ -525,28 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
...
...
@@ -579,18 +558,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
...
...
@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return
false
;
}
if
(
!
is_bf16_atomic_supported
()
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
&&
arg
.
KBatch
>
1
)
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <typeinfo>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
typename
ReduceDataType
=
CDataType
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGemm_Xdl_CShuffleV3R1
:
public
DeviceGemmV2R1
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
GemmAccDataType
,
CShuffleDataType
,
ReduceDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
PassThrough
,
GemmSpec
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
>
;
struct
Argument
:
public
GridwiseGemm
::
Argument
{
Argument
(
const
ADataType
*
p_a_grid_
,
const
BDataType
*
p_b_grid_
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_
,
CDataType
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideC_
,
index_t
k_batch_
)
:
GridwiseGemm
::
Argument
(
p_a_grid_
,
p_b_grid_
,
reinterpret_cast
<
ReduceDataType
*>
(
p_c_grid_
),
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
k_batch_
,
true
),
p_ds
(
p_ds_
),
StrideDs
(
StrideDs_
)
{
}
const
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
;
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
;
};
using
ReduceAdd
=
ck
::
reduce
::
Add
;
using
OutElementwiseOperation
=
CElementwiseOperation
;
static
constexpr
auto
DsVectorLengthSequence
=
generate_sequence_v2
(
[](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
std
::
is_same
<
CLayout
,
DLayout
>::
value
)
return
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{};
else
return
Number
<
1
>
{};
},
Number
<
NumDTensor
>
{});
using
DeviceReduceInstance
=
DeviceReduceThreadWiseMultiD
<
ReduceDataType
,
// InDataType,
DsDataType
,
// DsDatatype
GemmAccDataType
,
// AccDataType,
CDataType
,
// OutDataType,
3
,
// Rank
1
,
// NumReduceDim
ReduceAdd
,
PassThrough
,
OutElementwiseOperation
,
256
,
// BlockSize_,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// MThreadSliceSize_,
1
,
// KThreadSliceSize_,
0
,
// InSrcVectorDim_,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// InSrcVectorSize_,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// OutDstVectorSize_
decltype
(
DsVectorLengthSequence
)
>
;
// Invoker
struct
Invoker
:
public
BaseInvoker
{
float
RunReduce
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
static
constexpr
index_t
NumInDim
=
3
;
static
constexpr
index_t
NumOutDim
=
2
;
std
::
array
<
ck
::
index_t
,
NumInDim
>
in_lengths
=
{
arg
.
KBatch
,
arg
.
M
,
arg
.
N
};
std
::
array
<
ck
::
index_t
,
NumOutDim
>
out_lengths
=
{
arg
.
M
,
arg
.
N
};
std
::
array
<
ck
::
index_t
,
NumInDim
>
in_strides
;
std
::
array
<
ck
::
index_t
,
NumOutDim
>
out_strides
;
if
constexpr
(
std
::
is_same
<
CLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
in_strides
=
{
arg
.
M
*
arg
.
N
,
arg
.
N
,
1
};
out_strides
=
{
arg
.
N
,
1
};
}
else
{
in_strides
=
{
arg
.
M
*
arg
.
N
,
1
,
arg
.
M
};
out_strides
=
{
1
,
arg
.
M
};
}
std
::
array
<
int
,
1
>
reduce_dims
{
0
};
std
::
array
<
std
::
array
<
index_t
,
NumOutDim
>
,
NumDTensor
>
DsLengths
;
std
::
array
<
std
::
array
<
index_t
,
NumOutDim
>
,
NumDTensor
>
DsStrides
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
DsLengths
[
i
]
=
out_lengths
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
std
::
is_same
<
DLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
DsStrides
[
i
]
=
{
arg
.
StrideDs
[
i
],
1
};
}
else
{
DsStrides
[
i
]
=
{
1
,
arg
.
StrideDs
[
i
]};
}
});
auto
reduce
=
DeviceReduceInstance
{};
auto
argument_ptr
=
reduce
.
MakeArgumentPointer
(
in_lengths
,
in_strides
,
DsLengths
,
DsStrides
,
out_lengths
,
out_strides
,
reduce_dims
,
arg
.
p_workspace_
,
arg
.
p_ds
,
arg
.
p_c_grid
,
PassThrough
{},
OutElementwiseOperation
{});
auto
invoker_ptr
=
reduce
.
MakeInvokerPointer
();
float
ave_time
=
0
;
if
(
reduce
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
stream_config
);
}
else
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the device instance, exiting!"
);
}
return
ave_time
;
}
float
Run
(
const
Argument
&
arg_
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
auto
arg
=
*
dynamic_cast
<
const
typename
GridwiseGemm
::
Argument
*>
(
&
arg_
);
if
(
!
(
!
(
arg
.
IsReduceAdd
()
||
NumDTensor
>
0
)
&&
std
::
is_same
<
CDataType
,
ReduceDataType
>::
value
))
{
if
(
arg
.
p_workspace_
==
nullptr
)
{
throw
std
::
runtime_error
(
"using reduce , but empty workspace!"
);
}
arg
.
p_c_grid
=
reinterpret_cast
<
ReduceDataType
*>
(
arg
.
p_workspace_
);
}
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
KBatch
);
float
ave_time
=
0
;
index_t
k_grain
=
arg
.
KBatch
*
KPerBlock
;
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
stream_config
.
flush_cache
)
{
ck
::
utility
::
RotatingMemWrapper
<
typename
GridwiseGemm
::
Argument
>
rotating_mem
(
arg
,
stream_config
.
rotating_count
,
arg
.
M
*
arg
.
K
*
sizeof
(
ADataType
),
arg
.
K
*
arg
.
N
*
sizeof
(
BDataType
));
rotating_mem
.
Print
();
auto
run_flush_cache
=
[
&
]()
{
// flush icache
ck
::
utility
::
flush_icache
();
// rotating mem
rotating_mem
.
Next
();
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
stream_config
,
run_flush_cache
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
}
else
{
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
}
};
constexpr
index_t
minimum_occupancy
=
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
1
:
2
;
if
(
has_main_k_block_loop
)
{
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
if
(
!
(
!
(
arg
.
IsReduceAdd
()
||
NumDTensor
>
0
)
&&
std
::
is_same
<
CDataType
,
ReduceDataType
>::
value
))
{
// reduce c data
ave_time
+=
RunReduce
(
arg_
,
stream_config
);
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideC
,
index_t
KBatch
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
KBatch
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideC
,
index_t
KBatch
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
p_ds
,
static_cast
<
CDataType
*>
(
p_c
),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
KBatch
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
BlockGemmPipelineScheduler
,
std
::
string
>
BlkGemmPipelineSchedulerToString
{
{
BlockGemmPipelineScheduler
::
Intrawave
,
"Intrawave"
},
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
},
{
BlockGemmPipelineVersion
::
v3
,
"v3"
},
{
BlockGemmPipelineVersion
::
v4
,
"v4"
},
{
BlockGemmPipelineVersion
::
v5
,
"v5"
}};
// clang-format off
str
<<
"DeviceGemmXdlUniversalReduce"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
std
::
string
(
ALayout
::
name
)[
0
]
<<
std
::
string
(
BLayout
::
name
)[
0
]
<<
std
::
string
(
CLayout
::
name
)[
0
]
<<
">"
<<
" BlkSize: "
<<
BlockSize
<<
", "
<<
"BlkTile: "
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
", "
<<
"WaveTile: "
<<
MPerXDL
<<
"x"
<<
NPerXDL
<<
", "
<<
"WaveMap: "
<<
MXdlPerWave
<<
"x"
<<
NXdlPerWave
<<
", "
<<
"VmemReadVec: "
<<
ABlockTransferSrcScalarPerVector
<<
"x"
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
"BlkGemmPipelineScheduler: "
<<
BlkGemmPipelineSchedulerToString
[
BlkGemmPipeSched
]
<<
", "
<<
"BlkGemmPipelineVersion: "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
", "
<<
"BlkGemmPipelinePrefetchStages: "
<<
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
;
// clang-format on
return
str
.
str
();
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
auto
arg
=
*
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
!
(
!
(
arg
.
IsReduceAdd
()
||
NumDTensor
>
0
)
&&
std
::
is_same
<
CDataType
,
ReduceDataType
>::
value
))
{
std
::
cout
<<
"using workspace"
<<
std
::
endl
;
return
arg
.
M
*
arg
.
N
*
arg
.
KBatch
*
sizeof
(
ReduceDataType
);
}
return
0
;
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
bd689f40
...
...
@@ -36,7 +36,7 @@ template <typename GridwiseGemm,
typename
BGridDesc_BK0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ComputePtrOffsetOfBatch
,
index_t
Num
Batch
ToMerge
,
index_t
Num
Groups
ToMerge
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
...
...
@@ -56,7 +56,7 @@ __global__ void
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
*
Num
Batch
ToMerge
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
*
Num
Groups
ToMerge
);
const
index_t
k_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
num_k_per_block
);
const
long_index_t
a_batch_offset
=
...
...
@@ -92,7 +92,7 @@ template <typename GridwiseGemm,
typename
BGridDesc_BK0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ComputePtrOffsetOfBatch
,
index_t
Num
Batch
ToMerge
,
index_t
Num
Groups
ToMerge
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
...
...
@@ -113,7 +113,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__))
// offset base pointer for each work-group
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
*
Num
Batch
ToMerge
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
*
Num
Groups
ToMerge
);
const
index_t
k_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
num_k_per_block
);
const
long_index_t
a_batch_offset
=
...
...
@@ -189,7 +189,7 @@ template <ck::index_t NDimSpatial,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
index_t
Num
Batch
ToMerge
=
1
,
index_t
Num
Groups
ToMerge
=
1
,
typename
ComputeTypeA
=
InDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
...
...
@@ -238,7 +238,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
NPerBlock
,
K1Number
,
KPerBlock
/
K1Number
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
ConvBackwardWeightSpecialization
>
{};
static
constexpr
auto
conv_to_gemm_transformer_v1
=
...
...
@@ -638,7 +638,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
gemm_arg
.
M
,
gemm_arg
.
N
,
gemm_arg
.
KBatch
,
arg
.
Conv_G_
/
Num
Batch
ToMerge
);
gemm_arg
.
M
,
gemm_arg
.
N
,
gemm_arg
.
KBatch
,
arg
.
Conv_G_
/
Num
Groups
ToMerge
);
float
ave_time
=
0
;
...
...
@@ -724,7 +724,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
...
...
@@ -739,7 +739,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
...
...
@@ -760,7 +760,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -777,7 +777,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -796,7 +796,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -817,7 +817,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -838,7 +838,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -859,7 +859,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -879,7 +879,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -920,7 +920,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -937,7 +937,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -956,7 +956,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -977,7 +977,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -998,7 +998,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1019,7 +1019,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1039,7 +1039,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1060,7 +1060,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1084,7 +1084,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -1100,7 +1100,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -1119,7 +1119,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1135,7 +1135,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1157,7 +1157,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -1173,7 +1173,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
...
...
@@ -1192,7 +1192,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1208,7 +1208,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
...
...
@@ -1232,7 +1232,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
...
...
@@ -1247,7 +1247,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
Num
Batch
ToMerge
,
Num
Groups
ToMerge
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
...
...
@@ -1389,7 +1389,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
if
constexpr
(
Num
Batch
ToMerge
>
1
)
if
constexpr
(
Num
Groups
ToMerge
>
1
)
{
// support only if whole M and N can be proccessed on one block
if
(
!
(
GemmM
<=
MPerBlock
&&
GemmN
<=
NPerBlock
))
...
...
@@ -1400,7 +1400,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
return
false
;
}
if
(
arg
.
Conv_G_
%
Num
Batch
ToMerge
!=
0
)
if
(
arg
.
Conv_G_
%
Num
Groups
ToMerge
!=
0
)
{
return
false
;
}
...
...
@@ -1563,7 +1563,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<<
BlkGemmPipelineSchedulerToString
[
BlkGemmPipeSched
]
<<
", "
<<
"BlkGemmPipelineVersion: "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
", "
<<
Num
Batch
ToMerge
<<
Num
Groups
ToMerge
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
bd689f40
...
...
@@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_g_n_c_wis_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
MakeBGridDescriptor_BK0_N_BK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
...
...
@@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return
out_gemmm_gemmn_desc
;
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
]);
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer
);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
constexpr
static
ConvToGemmFwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
(
dummy_conv_to_gemm_transformer
))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
// GridwiseGemm
using
GridwiseGemm
=
...
...
@@ -426,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
conv_to_gemm_transformer_
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
conv_to_gemm_transformer_
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
conv_to_gemm_transformer_
)},
a_grid_desc_k0_m0_m1_k1_
{},
b_grid_desc_k0_n0_n1_k1_
{},
ds_grid_desc_m0_m10_m11_n0_n10_n11_
{},
...
...
@@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
...
...
@@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
]
);
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
});
// populate desc for Ds/E
...
...
@@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton
index_t
num_group_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
...
...
@@ -846,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op
};
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
...
@@ -890,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op
);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
bd689f40
...
...
@@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
c_g_n_k_wos_lengths
,
c_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_g_n_c_wis_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
MakeBGridDescriptor_BK0_N_BK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
...
...
@@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
template
<
typename
CLay
>
static
auto
MakeCGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_strides
)
static
auto
MakeCGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
CLay
>(
c_g_n_k_wos_lengths
,
c_g_n_k_wos_strides
,
c_g_n_k_wos_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
CLay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
// desc for problem definition
constexpr
static
ConvToGemmFwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
({},
{}))
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
<
CLayout
>
({},
{}))
>
;
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
<
CLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
// GridwiseGemm
using
GridwiseGemm
=
...
...
@@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_c_grid_
{
static_cast
<
CDataType
*>
(
p_c
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
c_g_n_k_wos_lengths
,
c_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
<
CLayout
>
(
c_g_n_k_wos_lengths
,
c_g_n_k_wos_strides
)},
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
conv_to_gemm_transformer_
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
conv_to_gemm_transformer_
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
<
CLayout
>
(
conv_to_gemm_transformer_
)},
a_grid_desc_k0_m0_m1_k1_
{},
b_grid_desc_k0_n0_n1_k1_
{},
c_grid_desc_m0_m10_m11_n0_n10_n11_
{},
...
...
@@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton
index_t
num_group_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
bd689f40
...
...
@@ -86,7 +86,6 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
groups_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -101,10 +100,8 @@ __global__ void
defined(__gfx94__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
gridDim
.
y
/
groups_count
);
const
index_t
&
num_blocks_per_n
=
groups_count
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
/
num_blocks_per_batch
);
const
index_t
n_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
/
num_blocks_per_n
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
index_t
n_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
long_index_t
e_batch_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_groups
.
GetEPtrOffset
(
g_idx
));
...
...
@@ -200,7 +197,6 @@ __global__ void
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
groups_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -318,38 +314,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
true
/*SplitN*/
,
ADataType
,
EDataType
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
index_t
Conv_N
)
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
Conv_N
);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -358,13 +336,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
static
auto
MakeBGridDescriptor_N_K
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
...
...
@@ -373,14 +348,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
index_t
Conv_N
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
Conv_N
);
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -390,27 +361,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
index_t
Conv_N
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
],
Conv_N
);
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer
);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{},
1
))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
1
))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{},
1
))
>
;
constexpr
static
ConvToGemmFwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
(
dummy_conv_to_gemm_transformer
))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
...
...
@@ -498,28 +469,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_N_per_block_
{
conv_to_gemm_transformer
.
template
GetSplitedNSize
<
ADataType
,
EDataType
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
conv_N_per_block_
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
conv_to_gemm_transformer_
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_N_per_block
_
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
conv_to_gemm_transformer
_
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
...
...
@@ -620,9 +587,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
1
]
*
conv_N_per_block_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_st
rides
[
i
],
conv_N_per_block_
);
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEG
rid
D
es
criptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
});
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
...
...
@@ -687,6 +665,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t
num_group_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_
;
index_t
conv_N_per_block_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
...
...
@@ -745,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg
.
a_g_n_c_wis_lengths_
[
I1
]
/
arg
.
conv_N_per_block_
;
const
index_t
gdx
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
index_t
gdy
=
arg
.
num_group_
*
num_workgroups_per_Conv_N
;
const
index_t
gdz
=
1
;
const
index_t
gdy
=
arg
.
num_group_
;
const
index_t
gdz
=
num_workgroups_per_Conv_N
;
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
...
...
@@ -795,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -839,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -1103,11 +1082,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op
};
}
static
auto
MakeArgument
(
APointers
p_as
,
BPointers
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_a
,
BPointers
p_b
,
APointers
p_a
s
,
BPointers
p_b
s
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
@@ -1126,8 +1178,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
return
std
::
make_unique
<
Argument
>
(
p_a
s
,
p_b
s
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
...
...
@@ -1147,6 +1199,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op
);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_as
,
BPointers
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
View file @
bd689f40
...
...
@@ -293,39 +293,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
true
/*SplitN*/
,
ADataType
,
EDataType
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
index_t
Conv_N
)
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
Conv_N
);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -344,12 +327,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
MakeBGridDescriptor_BK0_N_BK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
...
...
@@ -367,15 +348,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
index_t
Conv_N
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
Conv_N
);
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -384,7 +361,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{},
1
))
>
;
constexpr
static
ConvToGemmFwdTransformer
dummy_conv_to_gemm_transformer
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
#define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
...
...
@@ -417,9 +396,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{},
1
))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
({},
{}
))
>
;
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
...
...
@@ -450,27 +429,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_N_per_block_
{
conv_to_gemm_transformer
.
template
GetSplitedNSize
<
ADataType
,
EDataType
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_ak0_m_ak1_
{
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
conv_N_per_block_
)},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_ak0_m_ak1_
{
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
conv_to_gemm_transformer_
)},
b_grid_desc_bk0_n_bk1_
{
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_N_per_block
_
)},
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
conv_to_gemm_transformer_
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
conv_to_gemm_transformer
_
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
compute_ptr_offset_of_groups_
{},
compute_ptr_offset_of_n_
{},
...
...
@@ -519,6 +494,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton
index_t
num_group_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_
;
index_t
conv_N_per_block_
;
// tensor descriptors for block/thread-wise copy
...
...
@@ -1000,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
false
;
}
// Gridwise gemm v3 doesn't verify descriptors size
if
(
!
arg
.
conv_to_gemm_transformer_
.
AreDescriptorsSmallerThan2GB
())
{
return
false
;
}
// check Gridwise GEMM
const
index_t
GemmM
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
GemmN
=
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
...
...
@@ -1059,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op
};
}
static
auto
MakeArgument
(
const
void
*
p_as
,
const
void
*
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
...
@@ -1103,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op
);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
bd689f40
...
...
@@ -309,37 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_g_n_c_wis_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -348,13 +327,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
static
auto
MakeBGridDescriptor_N_K
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
...
...
@@ -363,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -447,11 +420,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return
GetPaddedRGridDescriptor
(
r_grid_desc_mraw
,
NHoWo
);
}
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
DELayout
>
({},
{}))
>
;
using
RGridDesc_M
=
remove_cvref_t
<
decltype
(
MakeRGridDescriptor_M
<
RLayout
>
({},
{}))
>
;
constexpr
static
ConvToGemmFwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
DELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
RGridDesc_M
=
remove_cvref_t
<
decltype
(
MakeRGridDescriptor_M
<
RLayout
>
({},
{}))
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -551,21 +527,23 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
p_rs_grid_
{},
// FIXME
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
conv_to_gemm_transformer_
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
DELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
DELayout
>
(
conv_to_gemm_transformer_
)},
r_grid_desc_m_
{
DeviceOp
::
MakeRGridDescriptor_M
<
RLayout
>
(
r_g_n_wos_lengths
,
r_g_n_wos_strides
)},
a_grid_desc_ak0_m_ak1_
{
...
...
@@ -621,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DELayout
>
(
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
]
);
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DELayout
>
(
conv_to_gemm_transformer_d
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
@@ -660,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType
*
p_e_grid_
;
typename
GridwiseGemm
::
RsGridPointer
p_rs_grid_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_
;
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
bd689f40
...
...
@@ -135,36 +135,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumGemmKPrefetchStage
>
1
);
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
static
auto
MakeAGridDescriptor
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_g_n_c_wis_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -205,12 +185,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
static
auto
MakeBGridDescriptor
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
...
...
@@ -251,13 +229,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
[
I1
]);
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -265,26 +240,27 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return
out_gemmm_gemmn_desc
;
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
]);
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer
);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
constexpr
static
ConvToGemmFwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc
=
decltype
(
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}));
using
BGridDesc
=
decltype
(
DeviceOp
::
MakeBGridDescriptor
<
BLayout
>
({},
{}));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
decltype
(
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
(
dummy_conv_to_gemm_transformer
));
using
BGridDesc
=
decltype
(
DeviceOp
::
MakeBGridDescriptor
<
BLayout
>
(
dummy_conv_to_gemm_transformer
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
(
dummy_conv_to_gemm_transformer
))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_Wmma
<
...
...
@@ -373,21 +349,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_
{
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_
{
DeviceOp
::
MakeBGridDescriptor
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
conv_to_gemm_transformer_
)},
a_grid_desc_
{
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
(
conv_to_gemm_transformer_
)},
b_grid_desc_
{
DeviceOp
::
MakeBGridDescriptor
<
BLayout
>
(
conv_to_gemm_transformer_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
)},
...
...
@@ -426,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
});
// D desc
ds_grid_desc_m_n_
=
DeviceOp
::
MakeDsGridDescriptor_M_N
(
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
);
ds_grid_desc_m_n_
=
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
},
Number
<
NumDTensor
>
{});
// populate desc for Ds/E
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
...
...
@@ -455,6 +447,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton
index_t
num_group_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
...
...
@@ -777,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op
};
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
...
@@ -823,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op
);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <queue>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
template
<
typename
GridwiseGemm
,
index_t
MaxGemmsNum
,
typename
GemmArgs
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
ComputePtrOffset
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle
(
Array
<
GemmArgs
,
MaxGemmsNum
>
gemm_desc_kernel_args
,
const
index_t
gemms_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
c_element_op
,
const
ComputePtrOffset
compute_ptr_offset_of_groups
,
const
ComputePtrOffset
compute_ptr_offset_of_n
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id_x
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
index_t
n_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
long_index_t
a_group_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_groups
.
GetAPtrOffset
(
g_idx
));
const
long_index_t
b_group_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_groups
.
GetBPtrOffset
(
g_idx
));
const
long_index_t
e_group_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_groups
.
GetEPtrOffset
(
g_idx
));
const
long_index_t
a_n_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_n
.
GetAPtrOffset
(
n_idx
));
const
long_index_t
e_n_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_n
.
GetEPtrOffset
(
n_idx
));
index_t
left
=
0
;
index_t
right
=
gemms_count
;
index_t
group_id
=
index_t
((
left
+
right
)
/
2
);
while
((
!
(
block_id_x
>=
gemm_desc_kernel_args
[
group_id
].
BlockStart_
&&
block_id_x
<
gemm_desc_kernel_args
[
group_id
].
BlockEnd_
))
&&
left
<=
right
)
{
if
(
block_id_x
<
gemm_desc_kernel_args
[
group_id
].
BlockStart_
)
{
right
=
group_id
;
}
else
{
left
=
group_id
;
}
group_id
=
index_t
((
left
+
right
)
/
2
);
}
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
gemm_desc_kernel_args
[
group_id
].
a_ptr_
+
a_group_offset
+
a_n_offset
,
gemm_desc_kernel_args
[
group_id
].
b_ptr_
+
b_group_offset
,
Tuple
<>
{},
gemm_desc_kernel_args
[
group_id
].
e_ptr_
+
e_group_offset
+
e_n_offset
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc_kernel_args
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
gemm_desc_kernel_args
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
Tuple
<>
{},
gemm_desc_kernel_args
[
group_id
].
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_kernel_args
[
group_id
].
block_2_etile_map_
);
#else
ignore
=
gemm_desc_kernel_args
;
ignore
=
gemms_count
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
compute_ptr_offset_of_groups
;
ignore
=
compute_ptr_offset_of_n
;
#endif
}
}
// namespace
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
typename
AComputeDataType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
()),
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename
BComputeDataType
=
AComputeDataType
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()
>
struct
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
:
public
DeviceGroupedConvFwdMultipleABD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
AComputeDataType
,
BComputeDataType
>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
MaxGemmsNum
=
32
;
static_assert
(
NumDTensor
==
0
,
"MultiD not supported."
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ConvToGemmFwdTransformerIndexT
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
true
/*SplitN*/
,
ADataType
,
EDataType
,
I1
,
index_t
>
;
using
ConvToGemmFwdTransformerLongIndexT
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
true
/*SplitN*/
,
ADataType
,
EDataType
,
I1
,
long_index_t
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemmFwdTransformerIndexT
&
conv_to_gemm_transformer
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
return
in_gemmm_gemmk_desc
;
}
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_N_K
(
const
ConvToGemmFwdTransformerIndexT
&
conv_to_gemm_transformer
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
return
wei_gemmn_gemmk_desc
;
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformerIndexT
&
conv_to_gemm_transformer
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
return
out_gemmm_gemmn_desc
;
}
// desc for problem definition
constexpr
static
ConvToGemmFwdTransformerIndexT
dummy_conv_to_gemm_transformer
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
static
auto
GenerateConvToGemmTransforms
(
ConvToGemmFwdTransformerLongIndexT
conv_to_gemm_transformer_base
,
const
ADataType
*
a_grid_ptr_base
,
EDataType
*
c_grid_ptr_base
)
{
// Max number of splits
// We need to use it to avoid infinity loop
constexpr
index_t
max_split_numbers
=
MaxGemmsNum
/
2
;
// Arrays to store transformers with smaller descs than 2GB
Array
<
ConvToGemmFwdTransformerIndexT
,
MaxGemmsNum
>
conv_to_gemm_transformers_arr
;
Array
<
const
ADataType
*
,
MaxGemmsNum
>
a_grid_ptrs_arr
;
Array
<
EDataType
*
,
MaxGemmsNum
>
c_grid_ptrs_arr
;
// Queue for spliting
std
::
queue
<
ConvToGemmFwdTransformerLongIndexT
>
conv_to_gemm_transformers_queue
(
{
conv_to_gemm_transformer_base
});
std
::
queue
<
const
ADataType
*>
a_grid_ptrs_queue
({
a_grid_ptr_base
});
std
::
queue
<
EDataType
*>
c_grid_ptrs_queue
({
c_grid_ptr_base
});
index_t
gemms_number
=
0
;
index_t
split_numbers
=
0
;
// Algorithm:
// While queue is not empty:
// 1. Get transformer from queue.
// 2. If descs are smaller than 2GB push to result array.
// 3. If descs are bigger than 2GB split into left and right transformer.
// and push the both into the queue.
while
(
!
conv_to_gemm_transformers_queue
.
empty
()
&&
split_numbers
<
max_split_numbers
&&
gemms_number
<
MaxGemmsNum
)
{
// Get transformer from the queue
const
auto
&
conv_to_gemm_transformer
=
conv_to_gemm_transformers_queue
.
front
();
const
ADataType
*
a_grid_ptr
=
a_grid_ptrs_queue
.
front
();
EDataType
*
c_grid_ptr
=
c_grid_ptrs_queue
.
front
();
// Check if convolution not exceed 2GB
if
(
conv_to_gemm_transformer
.
AreDescriptorsSmallerThan2GB
())
{
// If yes, push into result array
conv_to_gemm_transformers_arr
(
gemms_number
)
=
ConvToGemmFwdTransformerIndexT
{
conv_to_gemm_transformer
};
a_grid_ptrs_arr
(
gemms_number
)
=
a_grid_ptr
;
c_grid_ptrs_arr
(
gemms_number
)
=
c_grid_ptr
;
gemms_number
++
;
}
else
{
// If no, split into left and right convolutions
ConvToGemmFwdTransformerLongIndexT
conv_to_gemm_transformers_left_part
,
conv_to_gemm_transformers_right_part
;
const
ADataType
*
a_grid_right_ptr
;
EDataType
*
c_grid_right_ptr
;
ck
::
tie
(
conv_to_gemm_transformers_left_part
,
conv_to_gemm_transformers_right_part
,
a_grid_right_ptr
,
c_grid_right_ptr
)
=
conv_to_gemm_transformer
.
SplitConvProblem
(
a_grid_ptr
,
c_grid_ptr
);
conv_to_gemm_transformers_queue
.
push
(
conv_to_gemm_transformers_left_part
);
conv_to_gemm_transformers_queue
.
push
(
conv_to_gemm_transformers_right_part
);
// Left offsets remain the same
a_grid_ptrs_queue
.
push
(
a_grid_ptr
);
a_grid_ptrs_queue
.
push
(
a_grid_right_ptr
);
c_grid_ptrs_queue
.
push
(
c_grid_ptr
);
c_grid_ptrs_queue
.
push
(
c_grid_right_ptr
);
split_numbers
++
;
}
// Remove from the queue
conv_to_gemm_transformers_queue
.
pop
();
a_grid_ptrs_queue
.
pop
();
c_grid_ptrs_queue
.
pop
();
}
const
bool
is_split_valid
=
conv_to_gemm_transformers_queue
.
empty
();
return
ck
::
make_tuple
(
conv_to_gemm_transformers_arr
,
a_grid_ptrs_arr
,
c_grid_ptrs_arr
,
gemms_number
,
is_split_valid
);
}
#define GridwiseGemmTemplateParameters \
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType
// Use appropriate gridwise gemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Structure for each gemm(conv)
struct
GemmArgs
{
// pointers
const
ADataType
*
a_ptr_
;
const
BDataType
*
b_ptr_
;
EDataType
*
e_ptr_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
ck
::
index_t
BlockStart_
,
BlockEnd_
;
};
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
/*p_ds*/
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
/*ds_g_n_k_wos_lengths*/
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
/*ds_g_n_k_wos_strides*/
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
num_group_
{
static_cast
<
index_t
>
(
a_g_n_c_wis_lengths
[
0
])},
compute_ptr_offset_of_groups_
{},
compute_ptr_offset_of_n_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// Perform grouped gemm, generate array of tranformer for convolution
Array
<
ConvToGemmFwdTransformerIndexT
,
MaxGemmsNum
>
conv_to_gemm_transformer_arr
;
Array
<
const
ADataType
*
,
MaxGemmsNum
>
a_grid_ptrs
;
Array
<
EDataType
*
,
MaxGemmsNum
>
c_grid_ptrs
;
ck
::
tie
(
conv_to_gemm_transformer_arr
,
a_grid_ptrs
,
c_grid_ptrs
,
gemms_count_
,
is_split_valid_
)
=
GenerateConvToGemmTransforms
(
ConvToGemmFwdTransformerLongIndexT
{
a_g_n_c_wis_lengths_
,
a_g_n_c_wis_strides_
,
b_g_k_c_xs_lengths_
,
b_g_k_c_xs_strides_
,
e_g_n_k_wos_lengths_
,
e_g_n_k_wos_strides_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
},
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
EDataType
*>
(
p_e
));
grid_size_
=
0
;
valid_gemms_count_
=
0
;
if
(
is_split_valid_
)
{
// Create GemmArg for each gemm(conv)
for
(
index_t
i
=
0
;
i
<
gemms_count_
;
i
++
)
{
const
AGridDesc_M_K
a_grid_desc_m_k
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_arr
[
i
])};
const
BGridDesc_N_K
b_grid_desc_n_k
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
conv_to_gemm_transformer_arr
[
i
])};
const
auto
e_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
conv_to_gemm_transformer_arr
[
i
]);
const
auto
block_2_etile_map
=
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
);
const
index_t
grid_size_grp
=
block_2_etile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
Tuple
<>
{},
e_grid_desc_m_n
,
block_2_etile_map
))
{
gemm_desc_kernel_args_
(
valid_gemms_count_
)
=
GemmArgs
{
a_grid_ptrs
[
i
],
static_cast
<
const
BDataType
*>
(
p_b
),
c_grid_ptrs
[
i
],
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k
),
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
),
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
),
block_2_etile_map
,
BlockStart
,
BlockEnd
};
valid_gemms_count_
++
;
}
}
// N is the same for all convs
conv_N_per_block_
=
static_cast
<
index_t
>
(
conv_to_gemm_transformer_arr
[
I0
].
N_
);
}
// Strides for G and N remain the same
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
}
void
Print
()
const
{
for
(
index_t
i
=
0
;
i
<
valid_gemms_count_
;
i
++
)
{
std
::
cout
<<
"A[AK0, M, AK1]: "
<<
gemm_desc_kernel_args_
[
i
].
a_grid_desc_ak0_m_ak1_
<<
std
::
endl
;
std
::
cout
<<
"B[BK0, N, BK1]: "
<<
gemm_desc_kernel_args_
[
i
].
b_grid_desc_bk0_n_bk1_
<<
std
::
endl
;
std
::
cout
<<
"E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<<
gemm_desc_kernel_args_
[
i
].
e_grid_desc_mblock_mperblock_nblock_nperblock_
<<
std
::
endl
;
}
}
index_t
num_group_
;
index_t
conv_N_per_block_
;
Array
<
GemmArgs
,
MaxGemmsNum
>
gemm_desc_kernel_args_
;
index_t
grid_size_
;
index_t
gemms_count_
;
index_t
valid_gemms_count_
;
bool
is_split_valid_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
compute_ptr_offset_of_groups_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
compute_ptr_offset_of_n_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
long_index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
long_index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
DeviceOp
::
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
const
index_t
num_workgroups_per_Conv_N
=
arg
.
a_g_n_c_wis_lengths_
[
I1
]
/
arg
.
conv_N_per_block_
;
const
index_t
gdx
=
arg
.
grid_size_
;
const
index_t
gdy
=
arg
.
num_group_
;
const
index_t
gdz
=
num_workgroups_per_Conv_N
;
// K is constant for all gemms
const
auto
K
=
arg
.
gemm_desc_kernel_args_
[
I0
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
gemm_desc_kernel_args_
[
I0
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle
<
GridwiseGemm
,
MaxGemmsNum
,
GemmArgs
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
.
gemm_desc_kernel_args_
,
arg
.
gemms_count_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
compute_ptr_offset_of_groups_
,
arg
.
compute_ptr_offset_of_n_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
const
long_index_t
K
=
arg
.
b_g_k_c_xs_lengths_
[
I1
];
const
long_index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
I2
];
// Check if all descs are valid
if
(
!
(
arg
.
is_split_valid_
&&
arg
.
gemms_count_
==
arg
.
valid_gemms_count_
))
{
return
false
;
}
// check device
if
(
get_device_name
()
==
"gfx908"
)
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// check ConvolutionForwardSpecialization
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
3
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
ConvStride
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
return
false
;
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// check if it's 1x1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
3
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
return
false
;
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter3x3
)
{
if
(
C
!=
1
)
{
return
false
;
}
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
filter_spatial_dim
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
I3
];
if
(
filter_spatial_dim
!=
I3
)
{
return
false
;
}
}
if
constexpr
(
!
is_NSpatialGK_GKSpatial_NSpatialGC
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
false
;
}
}
// check vector access of A
// FIXME: layout
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
G_NW_C
>
||
is_same_v
<
ALayout
,
ctc
::
G_NHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
{
// Check access per C
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check vector access of B
// FIXME: layout
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
G_K_X_C
>
||
is_same_v
<
BLayout
,
ctc
::
G_K_YX_C
>
||
is_same_v
<
BLayout
,
ctc
::
G_K_ZYX_C
>
||
is_same_v
<
BLayout
,
ctc
::
GKXC
>
||
is_same_v
<
BLayout
,
ctc
::
GKYXC
>
||
is_same_v
<
BLayout
,
ctc
::
GKZYXC
>
||
is_same_v
<
BLayout
,
ctc
::
KXGC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXGC
>
||
is_same_v
<
BLayout
,
ctc
::
KZYXGC
>
)
{
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
{
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i64
;
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i64
;
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
input_left_pads_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
input_right_pads_i64
;
array_convert
(
a_g_n_c_wis_lengths_i64
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i64
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i64
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i64
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i64
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i64
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i64
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i64
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i64
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i64
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i64
,
input_left_pads
);
array_convert
(
input_right_pads_i64
,
input_right_pads
);
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i64
,
a_g_n_c_wis_strides_i64
,
b_g_k_c_xs_lengths_i64
,
b_g_k_c_xs_strides_i64
,
ds_g_n_k_wos_lengths_i64
,
ds_g_n_k_wos_strides_i64
,
e_g_n_k_wos_lengths_i64
,
e_g_n_k_wos_strides_i64
,
conv_filter_strides_i64
,
conv_filter_dilations_i64
,
input_left_pads_i64
,
input_right_pads_i64
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i64
;
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i64
;
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
input_left_pads_i64
;
std
::
array
<
long_index_t
,
NDimSpatial
>
input_right_pads_i64
;
array_convert
(
a_g_n_c_wis_lengths_i64
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i64
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i64
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i64
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i64
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i64
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i64
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i64
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i64
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i64
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i64
,
input_left_pads
);
array_convert
(
input_right_pads_i64
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i64
,
a_g_n_c_wis_strides_i64
,
b_g_k_c_xs_lengths_i64
,
b_g_k_c_xs_strides_i64
,
ds_g_n_k_wos_lengths_i64
,
ds_g_n_k_wos_strides_i64
,
e_g_n_k_wos_lengths_i64
,
e_g_n_k_wos_strides_i64
,
conv_filter_strides_i64
,
conv_filter_dilations_i64
,
input_left_pads_i64
,
input_right_pads_i64
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CDEBlockTransferScalarPerVector_NPerBlock
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
bd689f40
...
...
@@ -59,6 +59,22 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNDHWK
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NSpatialGK_GKSpatial_NSpatialGC
()
{
return
is_NWGK_GKXC_NWGC
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NHWGK_GKYXC_NHWGC
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NDHWGK_GKZYXC_NDHWGC
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNSpatialK_GKSpatial_GNSpatialC
()
{
return
is_GNWK_GKXC_GNWC
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHWK_GKYXC_GNHWC
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHWK_GKZYXC_GNDHWC
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
index_t
NumATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
struct
ComputePtrOffsetOfStridedBatch
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
View file @
bd689f40
...
...
@@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
conv_to_gemm_t
ransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
{}
;
using
ConvToGemmFwdT
ransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpecialization
::
MKPadding
,
index_t
,
index_t
,
index_t
>
{
...
...
@@ -97,19 +97,19 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths
[
I2
]
=
C
;
c_g_n_k_wos_lengths
[
I1
]
=
N
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer
{
a_g_n_c_wis_lengths
,
image_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
{},
// not needed for A Descriptor
c_g_n_k_wos_lengths
,
{},
// not needed for A Descriptor
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ImageLayout
>(
a_g_n_c_wis_lengths
,
image_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
{},
// not needed for A Descriptor
c_g_n_k_wos_lengths
,
{},
// not needed for A Descriptor
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
N
);
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ImageLayout
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -19,7 +19,7 @@ namespace device {
template
<
index_t
Rank
,
int
NumReduceDim
>
std
::
pair
<
long_index_t
,
long_index_t
>
get_2d_lengths
(
const
std
::
vector
<
index_t
>&
inLengths
)
{
static_assert
(
Rank
<=
6
,
"bigger Rank size not supported!"
);
static_assert
(
Rank
<=
12
,
"bigger Rank size not supported!"
);
long_index_t
invariant_total_length
=
1
;
long_index_t
reduce_total_length
=
1
;
...
...
@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
template
<
index_t
Rank
,
int
NumReduceDim
>
std
::
pair
<
long_index_t
,
long_index_t
>
get_2d_lengths
(
const
std
::
array
<
index_t
,
Rank
>&
inLengths
)
{
static_assert
(
Rank
<=
6
,
"bigger Rank size not supported!"
);
static_assert
(
Rank
<=
12
,
"bigger Rank size not supported!"
);
long_index_t
invariant_total_length
=
1
;
long_index_t
reduce_total_length
=
1
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
PropagateNan
,
OutputIndex
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
12
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"Invalid thread cluster size assignments!"
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
OutputIndex
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
12
,
"Bigger Rank size is not supported!"
);
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
0 → 100644
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataType
,
typename
DsDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
,
index_t
BlockSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
,
typename
DsVectorSizeSequence
>
struct
DeviceReduceThreadWiseMultiD
:
public
DeviceReduceMultiD
<
InDataType
,
DsDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
OutElementwiseOperation
>
{
static_assert
(
Rank
<=
12
,
"Bigger Rank size is not supported!"
);
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
(
MThreadSliceSize
%
OutDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
using
IndexDataType
=
int32_t
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumSrcDim
=
Rank
;
static
constexpr
index_t
NumDstDim
=
(
NumInvariantDim
==
0
)
?
1
:
NumInvariantDim
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static
constexpr
index_t
M_BlockTileSize
=
BlockSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
1
*
KThreadSliceSize
;
static
auto
MakeSrc2dDescriptor
(
const
std
::
array
<
index_t
,
Rank
>&
inLengths
,
const
std
::
array
<
index_t
,
Rank
>&
inStrides
)
{
const
auto
tupleSrcLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
inLengths
[
I
];
},
Number
<
Rank
>
{});
const
auto
tupleSrcStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
inStrides
[
I
];
},
Number
<
Rank
>
{});
const
auto
inDesc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
in_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
reduceAllDim
)
{
const
auto
one_dim_inDesc
=
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumSrcDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
transform_tensor_descriptor
(
one_dim_inDesc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
1
,
one_dim_inDesc
.
GetLength
(
Number
<
0
>
{})))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
}
else
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
generate_tuple
(
[
&
](
auto
I
)
{
return
inLengths
[
NumInvariantDim
+
I
];
},
Number
<
NumReduceDim
>
{});
const
auto
invariantDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
inLengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
return
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}();
const
auto
invariantLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
auto
inPad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
inPad_K
=
math
::
integer_least_multiple
(
reduceLength
,
K_BlockTileSize
)
-
reduceLength
;
auto
in_grid_desc_m_k_padded
=
transform_tensor_descriptor
(
in_grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
inPad_M
),
make_right_pad_transform
(
reduceLength
,
inPad_K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
in_grid_desc_m_k_padded
);
};
static
auto
MakeDst1dDescriptor
(
const
std
::
array
<
index_t
,
NumDstDim
>&
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>&
outStrides
)
{
const
auto
tupleDstLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
outLengths
[
I
];
},
Number
<
NumDstDim
>
{});
const
auto
tupleDstStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
outStrides
[
I
];
},
Number
<
NumDstDim
>
{});
auto
outDesc
=
make_naive_tensor_descriptor
(
tupleDstLengths
,
tupleDstStrides
);
auto
out_grid_desc_m
=
transform_tensor_descriptor
(
outDesc
,
make_tuple
(
make_merge_transform
(
tupleDstLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumDstDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
out_grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
outPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
out_grid_desc_m_padded
=
transform_tensor_descriptor
(
out_grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
outPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
out_grid_desc_m_padded
);
};
static
auto
MakeDsDescriptor
(
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths
,
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
DeviceReduceThreadWiseMultiD
::
MakeDst1dDescriptor
(
DsLengths
[
i
],
DsStrides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
using
InGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({},
{}));
using
OutGridDesc_M
=
decltype
(
MakeDst1dDescriptor
({},
{}));
using
DsGridDesc_M
=
decltype
(
MakeDsDescriptor
({},
{}));
using
GridwiseReduce
=
GridwiseReduction_mk_to_m_threadwise_multi_d
<
InDataType
,
DsDataType
,
OutDataType
,
AccDataType
,
InGridDesc_M_K
,
DsGridDesc_M
,
OutGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
OutElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
BlockSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
DsVectorSizeSequence
>
;
using
DsGridPointer
=
typename
GridwiseReduce
::
DsGridPointer
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides
,
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
InDataType
*
in_dev
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
ds_dev
,
OutDataType
*
out_dev
,
const
InElementwiseOperation
in_elementwise_op
,
const
OutElementwiseOperation
out_elementwise_op
)
:
DsLengths_
{
DsLengths
},
DsStrides_
{
DsStrides
},
outLengths_
{
outLengths
},
outStrides_
{
outStrides
},
in_dev_
{
in_dev
},
out_dev_
{
out_dev
},
in_elementwise_op_
{
in_elementwise_op
},
out_elementwise_op_
{
out_elementwise_op
}
{
inLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inLengths
,
reduceDims
);
inStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inStrides
,
reduceDims
);
std
::
tie
(
invariant_total_length
,
reduce_total_length
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
inLengths_
);
if
constexpr
(
NumInvariantDim
==
0
)
invariant_lowest_length
=
1
;
else
invariant_lowest_length
=
inLengths_
[
NumInvariantDim
-
1
];
reduce_lowest_length
=
inLengths_
[
Rank
-
1
];
numBlockTileIteration
=
(
reduce_total_length
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
gridSize
=
math
::
integer_least_multiple
(
invariant_total_length
,
M_BlockTileSize
)
/
M_BlockTileSize
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
ds_dev
[
i
]);
});
ds_grid_desc_m_
=
MakeDsDescriptor
(
DsLengths
,
DsStrides
);
}
std
::
array
<
index_t
,
Rank
>
inLengths_
;
std
::
array
<
index_t
,
Rank
>
inStrides_
;
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides_
;
std
::
array
<
index_t
,
NumDstDim
>
outLengths_
;
std
::
array
<
index_t
,
NumDstDim
>
outStrides_
;
const
InDataType
*
in_dev_
;
OutDataType
*
out_dev_
;
DsGridPointer
p_ds_grid_
;
InElementwiseOperation
in_elementwise_op_
;
OutElementwiseOperation
out_elementwise_op_
;
DsGridDesc_M
ds_grid_desc_m_
;
index_t
invariant_lowest_length
;
index_t
reduce_lowest_length
;
long_index_t
invariant_total_length
;
long_index_t
reduce_total_length
;
int
numBlockTileIteration
;
size_t
gridSize
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
in_grid_desc_m_k
=
DeviceReduceThreadWiseMultiD
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
);
const
auto
out_grid_desc_m
=
DeviceReduceThreadWiseMultiD
::
MakeDst1dDescriptor
(
arg
.
outLengths_
,
arg
.
outStrides_
);
float
avg_time
=
0
;
const
auto
kernel
=
kernel_reduce_threadwise_multi_d
<
GridwiseReduce
,
InDataType
,
OutDataType
,
AccDataType
,
InGridDesc_M_K
,
DsGridDesc_M
,
OutGridDesc_M
,
InElementwiseOperation
,
OutElementwiseOperation
,
DsGridPointer
>
;
avg_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
in_grid_desc_m_k
,
arg
.
ds_grid_desc_m_
,
out_grid_desc_m
,
arg
.
in_elementwise_op_
,
arg
.
out_elementwise_op_
,
arg
.
in_dev_
,
arg
.
p_ds_grid_
,
arg
.
out_dev_
);
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
constexpr
(
InSrcVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
{
return
(
false
);
}
else
{
if
(
pArg
->
inStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
);
if
(
pArg
->
invariant_lowest_length
%
InSrcVectorSize
!=
0
)
return
(
false
);
};
}
else
{
if
(
pArg
->
inStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
);
if
(
pArg
->
reduce_lowest_length
%
InSrcVectorSize
!=
0
)
return
(
false
);
};
// To improve
if
(
pArg
->
invariant_lowest_length
%
OutDstVectorSize
!=
0
)
return
(
false
);
std
::
cerr
<<
"reduce_total_length = "
<<
pArg
->
reduce_total_length
<<
" KThreadSliceSize = "
<<
KThreadSliceSize
<<
std
::
endl
;
// cases with big reduce_total_length should be handled by Blockwise kernel
if
(
pArg
->
reduce_total_length
/
KThreadSliceSize
>=
32
)
return
(
false
);
return
(
true
);
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides
,
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
void
*
in_dev
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
ds_dev
,
void
*
out_dev
,
const
InElementwiseOperation
in_elementwise_op
,
const
OutElementwiseOperation
out_elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
DsLengths
,
DsStrides
,
outLengths
,
outStrides
,
reduceDims
,
static_cast
<
const
InDataType
*>
(
in_dev
),
ds_dev
,
static_cast
<
OutDataType
*>
(
out_dev
),
in_elementwise_op
,
out_elementwise_op
);
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceReduceThreadWiseMultiD<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
BlockSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
1
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcVectorDim_"
<<
InSrcVectorDim
<<
"_InSrcVectorSize_"
<<
InSrcVectorSize
<<
"_OutDstVectorSize_"
<<
OutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
bd689f40
...
...
@@ -638,6 +638,32 @@ struct AddSilu
}
};
struct
ConvScaleAdd
{
__host__
__device__
ConvScaleAdd
(
float
scale_in
=
1.
f
,
float
scale_wei
=
1.
f
,
float
scale_out
=
1.
f
)
:
scale_in_
(
scale_in
),
scale_wei_
(
scale_wei
),
scale_out_
(
scale_out
)
{
}
template
<
typename
E
,
typename
C
,
typename
D
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
float
,
float
>
(
f8_t
&
e
,
const
float
&
c
,
const
float
&
d
)
const
{
float
x
;
Add
{}.
template
operator
()
<
float
>(
x
,
c
*
scale_in_
*
scale_wei_
,
d
);
e
=
type_convert
<
f8_t
>
(
x
*
scale_out_
);
};
float
scale_in_
;
float
scale_wei_
;
float
scale_out_
;
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -249,6 +249,31 @@ struct MultiplyAdd
}
};
struct
MultiplyMultiply
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
float
,
float
,
float
>
(
ck
::
half_t
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x0_f
=
c
*
d0
*
d1
;
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
bhalf_t
,
float
,
float
,
float
>
(
ck
::
bhalf_t
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x0_f
=
c
*
d0
*
d1
;
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
};
struct
MultiplyAddFastGelu
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
bd689f40
...
...
@@ -431,7 +431,7 @@ struct Relu
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_
_expf
" and "rcp" function
// gpu code use lower accuracy "_
ocml_exp_f32
" and "rcp" function
struct
FastGelu
{
template
<
typename
Y
,
typename
X
>
...
...
@@ -451,7 +451,7 @@ struct FastGelu
y
=
x
/
(
1.
f
+
emu
);
}
// device code, use lower precision "__
expf
" and "rcp"
// device code, use lower precision "__
ocml_exp_f32
" and "rcp"
template
<
>
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
...
...
@@ -459,7 +459,7 @@ struct FastGelu
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
__
expf
(
u
);
const
float
emu
=
__
ocml_exp_f32
(
u
);
y
=
x
*
ck
::
math
::
rcp
(
1.
f
+
emu
);
}
...
...
@@ -1025,6 +1025,31 @@ struct ConvScale
float
scale_out_
;
};
struct
ConvScaleRelu
{
__host__
__device__
ConvScaleRelu
(
float
scale_in
=
1.
f
,
float
scale_wei
=
1.
f
,
float
scale_out
=
1.
f
)
:
scale_in_
(
scale_in
),
scale_wei_
(
scale_wei
),
scale_out_
(
scale_out
)
{
}
template
<
typename
E
,
typename
C
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
float
>
(
f8_t
&
e
,
const
float
&
c
)
const
{
float
x
;
Relu
{}.
template
operator
()
<
float
>(
x
,
c
*
scale_in_
*
scale_wei_
);
e
=
type_convert
<
f8_t
>
(
x
*
scale_out_
);
};
float
scale_in_
;
float
scale_wei_
;
float
scale_out_
;
};
// support fastconvert of int8 to fp16
template
<
typename
InputDataType
,
typename
OutputDataType
,
index_t
RegPackNumber
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
View file @
bd689f40
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
if
(
thread_k_cluster_id
==
0
)
{
if
(
block_group_size
==
0
&&
!
float_equal_zero
{}(
beta_values
[
iR
]))
if
(
!
float_equal_zero
{}(
beta_values
[
iR
]))
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment