gaoqiong / composable_kernel · Commits

Commit f30e5975, authored Oct 24, 2023 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: 91b414cd, bec84efb
Changes: 207 files in total; showing 20 changed files with 2452 additions and 248 deletions (+2452 −248)
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp                                          +6    -3
include/ck/tensor_operation/gpu/device/device_normalization.hpp                                        +5    -3
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp           +847  -0
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp                                +22   -0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp                       +88   -17
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp  +1    -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp                      +3    -28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp           +877  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp            +7    -28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp       +1    -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp                              +8    -1
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp                              +98   -23
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp                       +174  -84
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                               +103  -12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp                       +3    -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp                        +27   -14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp                                     +2    -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                                   +18   -9
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp           +112  -5
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp                 +50   -16
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp

@@ -20,7 +20,8 @@ template <typename ALayout,
           typename CDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          typename ComputeType = CDataType>
 struct DeviceGemmSplitK : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
@@ -48,7 +49,8 @@ template <typename ALayout,
           typename CDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          typename ComputeType = CDataType>
 using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
                                                              BLayout,
                                                              CLayout,
@@ -57,7 +59,8 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
                                                              CDataType,
                                                              AElementwiseOperation,
                                                              BElementwiseOperation,
-                                                             CElementwiseOperation>>;
+                                                             CElementwiseOperation,
+                                                             ComputeType>>;

 } // namespace device
 } // namespace tensor_operation
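The hunks above add an optional ComputeType parameter, defaulting to CDataType, to both the DeviceGemmSplitK interface and the DeviceGemmSplitKPtr alias, so existing instantiations keep compiling unchanged. A rough sketch of the intent (illustrative only; the template parameters not visible in these hunks are assumed, and ck::half_t is just an example compute type):

// Existing code: ComputeType silently defaults to CDataType.
//     DeviceGemmSplitK<..., AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
// New option: request a different compute precision explicitly.
//     DeviceGemmSplitK<..., AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
//                      ck::half_t>   // ComputeType
// Concrete implementations such as DeviceGemmXdlSplitKCShuffle (changed later in this commit)
// forward this ComputeType to their gridwise GEMM.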
include/ck/tensor_operation/gpu/device/device_normalization.hpp

@@ -14,8 +14,8 @@ namespace device {
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
                         const std::vector<index_t> gammaStrides,
                         const std::vector<index_t> betaStrides,
                         const std::vector<index_t> yStrides,
+                        const std::vector<index_t> saveMeanStrides,
+                        const std::vector<index_t> saveInvStdStrides,
                         const std::vector<index_t> reduceDims,
                         double epsilon,
                         const void* p_x,
@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
                                                                    GammaDataType,
                                                                    BetaDataType,
-                                                                   ComputeDataType,
                                                                    YDataType,
+                                                                   SaveMeanInvStdDataType,
                                                                    YElementwiseOperation,
                                                                    Rank,
                                                                    NumReduceDim>>;
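The normalization interface drops ComputeDataType and instead takes a SaveMeanInvStdDataType plus strides for two new side outputs, so the forward pass can also store the per-reduction mean and inverse standard deviation (typically reused by a backward pass). A minimal host-side reference of what those two buffers are expected to hold for one reduced row, assuming float storage (an illustration, not code from the library):

#include <cmath>
#include <cstddef>
#include <vector>

// Naive layernorm-style forward for a single reduced row x[0..n-1].
void naive_normalization_fwd(const std::vector<float>& x,
                             const std::vector<float>& gamma,
                             const std::vector<float>& beta,
                             double epsilon,
                             std::vector<float>& y,
                             float& save_mean,
                             float& save_inv_std)
{
    const std::size_t n = x.size();

    double mean = 0.0;
    for(float v : x)
        mean += v;
    mean /= n;

    double var = 0.0;
    for(float v : x)
        var += (v - mean) * (v - mean);
    var /= n;

    const double inv_std = 1.0 / std::sqrt(var + epsilon);
    for(std::size_t i = 0; i < n; ++i)
        y[i] = static_cast<float>(gamma[i] * (x[i] - mean) * inv_std + beta[i]);

    // These are the values the new saveMean / saveInvStd buffers receive,
    // stored as SaveMeanInvStdDataType by the device op.
    save_mean    = static_cast<float>(mean);
    save_inv_std = static_cast<float>(inv_std);
}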
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename AsPointer, typename BsPointer, typename DsPointer, typename EDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AsGridDesc_AK0_M_AK1, typename BsGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_abd_xdl_cshuffle(
            AsPointer p_as_grid,
            BsPointer p_bs_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
            const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid, p_bs_grid, p_ds_grid, p_e_grid,
                                                  p_shared,
                                                  a_element_op, b_element_op, cde_element_op,
                                                  as_grid_desc_ak0_m_ak1,
                                                  bs_grid_desc_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
#else
    ignore = p_as_grid;
    ignore = p_bs_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = as_grid_desc_ak0_m_ak1;
    ignore = bs_grid_desc_bk0_n_bk1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_etile_map;
#endif
}

} // namespace ck

namespace ck {
namespace tensor_operation {
namespace device {

// GEMM:
//   input : A[M, K]
//   input : B[N, K]
//   input : D0[M, N], D1[M, N], ...
//   output : E[M, N]
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
// Assume:
//   D0, D1, ... and E have the same layout
template <index_t NumDimM, index_t NumDimN, index_t NumDimK,
          typename AsDataType, typename BsDataType, typename AccDataType,
          typename CShuffleDataType, typename DsDataType, typename EDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage, index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t AK1, index_t BK1,
          index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceContractionMultipleABD_Xdl_CShuffle
    : public DeviceContractionMultipleABD<NumDimM, NumDimN, NumDimK,
                                          AsDataType, BsDataType, DsDataType, EDataType,
                                          AElementwiseOperation, BElementwiseOperation,
                                          CDEElementwiseOperation>
{
    using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle;

    static constexpr index_t NumATensor = AsDataType::Size();
    static constexpr index_t NumBTensor = BsDataType::Size();
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ComputeDataType = EDataType;

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
        AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType,
        DsDataType, EDataType,
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        NumGemmKPrefetchStage, BlockSize,
        MPerBlock, NPerBlock, KPerBlock,
        AK1, BK1,
        MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        PipelineVer>;

    static constexpr auto matrix_padder =
        ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock};

    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_,
                                        const std::vector<index_t>& a_ms_ks_strides_)
    {
        assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK &&
               a_ms_ks_strides_.size() == NumDimM + NumDimK);

        const auto to_tuple = [&](auto& vec, auto num) {
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };

        const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number<NumDimM + NumDimK>{});
        const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number<NumDimM + NumDimK>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
        const auto a_grid_desc_ms_ks =
            make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);

        // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
        const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
            a_grid_desc_ms_ks,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
            make_tuple(mDimIds, kDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
    }

    __host__ __device__ static auto MakeAsGridDescriptor_M_K(
        const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
        const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]);
            },
            Number<NumATensor>{});
    }

    // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
    static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_,
                                        const std::vector<index_t>& b_ns_ks_strides_)
    {
        assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK &&
               b_ns_ks_strides_.size() == NumDimN + NumDimK);

        const auto to_tuple = [&](auto& vec, auto num) {
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };

        const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number<NumDimN + NumDimK>{});
        const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number<NumDimN + NumDimK>{});

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

        // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
        const auto b_grid_desc_ns_ks =
            make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

        // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
        const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
            b_grid_desc_ns_ks,
            make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
            make_tuple(nDimIds, kDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
    }

    __host__ __device__ static auto MakeBsGridDescriptor_N_K(
        const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
        const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]);
            },
            Number<NumBTensor>{});
    }

    // assume E[M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_,
                                        const std::vector<index_t>& e_ms_ns_strides_)
    {
        assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN &&
               e_ms_ns_strides_.size() == NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto num) {
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };

        const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number<NumDimM + NumDimN>{});
        const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number<NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);

        // lengths for K0, K1, ...
        const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);

        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
        const auto e_grid_desc_ms_ns =
            make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);

        // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
        const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
            e_grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds, nDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
    }

    static auto MakeDsGridDescriptor_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
        const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
            },
            Number<NumDTensor>{});
    }

    // desc for problem definition
    using AsGridDesc_M_K = remove_cvref_t<decltype(MakeAsGridDescriptor_M_K({}, {}))>;
    using BsGridDesc_N_K = remove_cvref_t<decltype(MakeBsGridDescriptor_N_K({}, {}))>;
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N({}, {}))>;

    // desc for blockwise copy
    using AsGridDesc_AK0_M_AK1 =
        remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
    using BsGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

    // block-to-e-tile map
    using Block2ETileMap =
        remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::array<const void*, NumATensor> p_as_grid,
                 std::array<const void*, NumBTensor> p_bs_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
                 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
                 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
                 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
                 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
                 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
                 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
                 const std::vector<index_t>& e_ms_ns_length,
                 const std::vector<index_t>& e_ms_ns_stride,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_as_grid_{},
              p_bs_grid_{},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              as_grid_desc_m_k_{},
              bs_grid_desc_n_k_{},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)},
              as_grid_desc_ak0_m_ak1_{},
              bs_grid_desc_bk0_n_bk1_{},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            // populate pointer, desc for As
            static_for<0, NumATensor, 1>{}([&](auto i) {
                // using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
                using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;

                // A pointer
                p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);

                // A desc
                as_grid_desc_m_k_(i) =
                    MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
            });

            // populate pointer, desc for Bs
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                // using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
                using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;

                // B pointer
                p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);

                // B desc
                bs_grid_desc_n_k_(i) =
                    MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
            });

            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                // D desc
                ds_grid_desc_m_n_(i) =
                    MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
            });

            // populate desc for Ds/E
            if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, bs_grid_desc_n_k_,
                                           ds_grid_desc_m_n_, e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                as_grid_desc_ak0_m_ak1_ =
                    GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);

                bs_grid_desc_bk0_n_bk1_ =
                    GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);

                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        ds_grid_desc_m_n_);

                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);
            }

            // for sanity check of vector memory access
            for(index_t i = 0; i < NumATensor; ++i)
            {
                a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
                a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
            }

            for(index_t i = 0; i < NumBTensor; ++i)
            {
                b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
                b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
            }

            for(index_t i = 0; i < NumDTensor; ++i)
            {
                ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
            }

            e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
        }

        // pointers
        typename GridwiseGemm::AsGridPointer p_as_grid_;
        typename GridwiseGemm::BsGridPointer p_bs_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for problem definiton
        AsGridDesc_M_K as_grid_desc_m_k_;
        BsGridDesc_N_K bs_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
        BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        // Strides for the last M/N/K dimensions of A/B/Ds/E
        //   for sanity check of vector load/store
        std::array<index_t, NumATensor> a_mz_stride_;
        std::array<index_t, NumATensor> a_kz_stride_;
        std::array<index_t, NumBTensor> b_nz_stride_;
        std::array<index_t, NumBTensor> b_kz_stride_;
        std::array<index_t, NumDTensor> ds_nz_stride_;
        index_t e_nz_stride_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, arg.bs_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle<
                    GridwiseGemm,
                    typename GridwiseGemm::AsGridPointer,
                    typename GridwiseGemm::BsGridPointer,
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AsGridDesc_AK0_M_AK1,
                    DeviceOp::BsGridDesc_BK0_N_BK1,
                    DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    DeviceOp::Block2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config, kernel,
                                              dim3(grid_size), dim3(BlockSize), 0,
                                              arg.p_as_grid_, arg.p_bs_grid_,
                                              arg.p_ds_grid_, arg.p_e_grid_,
                                              arg.a_element_op_, arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.as_grid_desc_ak0_m_ak1_,
                                              arg.bs_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_etile_map_);
            };

            const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        // check vector load/store
        {
            bool all_valid = true;

            static_for<0, NumATensor, 1>{}([&](auto i) {
                // vector memory access of A: could be on M or AK1 dimension
                if constexpr(ABlockTransferSrcVectorDim == 1)
                {
                    if(!(arg.a_mz_stride_[i] == 1 &&
                         arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
                                 ABlockTransferSrcScalarPerVector ==
                             0))
                    {
                        all_valid = false;
                    }
                }
                else
                {
                    if(!(arg.a_kz_stride_[i] == 1 &&
                         arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
                                 ABlockTransferSrcScalarPerVector ==
                             0))
                    {
                        all_valid = false;
                    }
                }
            });

            // vector memory access of B: could be on N or BK1 dimension
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                if constexpr(BBlockTransferSrcVectorDim == 1)
                {
                    if(!(arg.b_nz_stride_[i] == 1 &&
                         arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
                                 BBlockTransferSrcScalarPerVector ==
                             0))
                    {
                        all_valid = false;
                    }
                }
                else
                {
                    if(!(arg.b_kz_stride_[i] == 1 &&
                         arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
                                 BBlockTransferSrcScalarPerVector ==
                             0))
                    {
                        all_valid = false;
                    }
                }
            });

            // check vector load of Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                if(!(arg.ds_nz_stride_[i] == 1 &&
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
                             CDEBlockTransferScalarPerVector_NPerBlock ==
                         0))
                {
                    all_valid = false;
                }
            });

            // vector memory access of E: always on NPerBlock dimension
            if(!(arg.e_nz_stride_ == 1 &&
                 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
                         CDEBlockTransferScalarPerVector_NPerBlock ==
                     0))
            {
                all_valid = false;
            }

            if(!all_valid)
            {
                return false;
            }
        }

        return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, arg.bs_grid_desc_n_k_,
                                           arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::array<const void*, NumATensor> p_as,
                             std::array<const void*, NumBTensor> p_bs,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
                             const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
                             const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
                             const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
                             const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
                             const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
                             const std::vector<index_t>& e_ms_ns_length,
                             const std::vector<index_t>& e_ms_ns_stride,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_as, p_bs, p_ds, p_e,
                        a_ms_ks_lengths, a_ms_ks_strides,
                        b_ns_ks_lengths, b_ns_ks_strides,
                        d_ms_ns_lengths, d_ms_ns_strides,
                        e_ms_ns_length, e_ms_ns_stride,
                        a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
                        std::array<const void*, NumBTensor> p_bs,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
                        const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides,
                        const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
                        const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
                        const std::vector<index_t>& e_ms_ns_length,
                        const std::vector<index_t>& e_ms_ns_stride,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_as, p_bs, p_ds, p_e,
                                          as_ms_ks_lengths, as_ms_ks_strides,
                                          bs_ns_ks_lengths, bs_ns_ks_strides,
                                          ds_ms_ns_lengths, ds_ms_ns_strides,
                                          e_ms_ns_length, e_ms_ns_stride,
                                          a_element_op, b_element_op, cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceContractionMultipleABD_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">"
            << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
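The new device op folds an arbitrary-rank contraction into a 2D GEMM: the descriptor helpers above merge the M-, N-, and K-groups of dimensions into single MRaw/NRaw/KRaw axes and then pad them to the block sizes. A small worked illustration of that folding and of the host-side lengths/strides layout (the numbers and the single-A/B/D configuration are hypothetical):

// With NumDimM = NumDimK = 2, an A tensor of shape [M0, M1, K0, K1] = [8, 16, 4, 32]
// stored densely (strides [2048, 128, 32, 1]) is viewed by MakeAGridDescriptor_M_K as
//   A[MRaw, KRaw] = A[M0 * M1, K0 * K1] = A[128, 128],
// which matrix_padder then pads up to multiples of MPerBlock / KPerBlock.
// The caller passes one lengths/strides vector per A/B/D tensor, e.g.:
//   std::array<std::vector<ck::index_t>, 1> a_lengths{{{8, 16, 4, 32}}};
//   std::array<std::vector<ck::index_t>, 1> a_strides{{{2048, 128, 32, 1}}};
// with analogous arrays for B (N-dims then K-dims) and for each D/E (M-dims then N-dims),
// and then builds the argument with MakeArgument(...) / MakeArgumentPointer(...) as declared above.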
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp

@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
     {
         return std::make_unique<Invoker>();
     };

+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceElementwiseImpl<";
+        str << "NumDim_" << NumDim << ",";
+        str << "MPerThread_" << MPerThread << ",";
+        str << "InScalarPerVector";
+        static_for<0, InScalarPerVectorSeq::Size(), 1>{}(
+            [&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
+        str << ",";
+        str << "OutScalarPerVector";
+        static_for<0, OutScalarPerVectorSeq::Size(), 1>{}(
+            [&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
+        str << ">";
+        // clang-format on
+
+        return str.str();
+    }
 }; // namespace device
 } // namespace device
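With the added GetTypeString, each elementwise instance can report its configuration. For a hypothetical instance with NumDim = 4, MPerThread = 8, InScalarPerVectorSeq = Sequence<8, 8> and OutScalarPerVectorSeq = Sequence<8>, the returned string would be:

// DeviceElementwiseImpl<NumDim_4,MPerThread_8,InScalarPerVector_8_8,OutScalarPerVector_8>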
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp

@@ -69,7 +69,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                              CDataType,
                                                              AElementwiseOperation,
                                                              BElementwiseOperation,
-                                                             CElementwiseOperation>
+                                                             CElementwiseOperation,
+                                                             ComputeType>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -126,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                 PipelineVer,
                                                 ComputeType>;

-    using Argument = typename GridwiseGemm::Argument;
+    struct Argument : public GridwiseGemm::Argument
+    {
+        Argument(const ADataType* p_a_grid_,
+                 const BDataType* p_b_grid_,
+                 CDataType* p_c_grid_,
+                 index_t M_,
+                 index_t N_,
+                 index_t K_,
+                 index_t StrideA_,
+                 index_t StrideB_,
+                 index_t StrideC_,
+                 index_t MPadded_,
+                 index_t NPadded_,
+                 index_t KPadded_,
+                 index_t K0_,
+                 index_t k_batch_,
+                 AElementwiseOperation a_element_op_,
+                 BElementwiseOperation b_element_op_,
+                 CElementwiseOperation c_element_op_)
+            : GridwiseGemm::Argument(p_a_grid_, p_b_grid_, p_c_grid_, M_, N_, K_,
+                                     StrideA_, StrideB_, StrideC_,
+                                     MPadded_, NPadded_, KPadded_, K0_, k_batch_),
+              a_element_op(a_element_op_),
+              b_element_op(b_element_op_),
+              c_element_op(c_element_op_)
+        {
+        }
+
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        CElementwiseOperation c_element_op;
+    };

     using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;

     // Invoker
@@ -167,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                 karg.M * karg.N * sizeof(CDataType),
                 stream_config.stream_id_));

-            ave_time = launch_and_time_kernel(
-                stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
+            ave_time =
+                launch_and_time_kernel(stream_config,
+                                       kernel,
+                                       dim3(gdx, gdy, gdz),
+                                       dim3(BlockSize),
+                                       0,
+                                       static_cast<typename GridwiseGemm::Argument>(karg),
+                                       b2c_map,
+                                       karg.a_element_op,
+                                       karg.b_element_op,
+                                       karg.c_element_op);
         };

         if(has_main_k0_block_loop)
@@ -179,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                 kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                      true,
                                                      InMemoryDataOperationEnum::Set,
-                                                     DefaultBlock2CTileMap>;
+                                                     DefaultBlock2CTileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CElementwiseOperation>;

                 Run(kernel);
             }
@@ -189,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                 kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                      true,
                                                      InMemoryDataOperationEnum::AtomicAdd,
-                                                     DefaultBlock2CTileMap>;
+                                                     DefaultBlock2CTileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CElementwiseOperation>;

                 Run(kernel);
             }
@@ -202,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                 kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                      false,
                                                      InMemoryDataOperationEnum::Set,
-                                                     DefaultBlock2CTileMap>;
+                                                     DefaultBlock2CTileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CElementwiseOperation>;

                 Run(kernel);
             }
@@ -212,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                 kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                      false,
                                                      InMemoryDataOperationEnum::AtomicAdd,
-                                                     DefaultBlock2CTileMap>;
+                                                     DefaultBlock2CTileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CElementwiseOperation>;

                 Run(kernel);
             }
@@ -260,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
-                             AElementwiseOperation,
-                             BElementwiseOperation,
-                             CElementwiseOperation,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op,
                              index_t KBatch)
     {
-        return Argument{p_a,
+        return Argument(p_a,
                         p_b,
                         p_c,
                         M,
@@ -278,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         GridwiseGemm::CalculateNPadded(N),
                         GridwiseGemm::CalculateKPadded(K, KBatch),
                         GridwiseGemm::CalculateK0(K, KBatch),
-                        KBatch};
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op);
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -293,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         index_t StrideA,
                         index_t StrideB,
                         index_t StrideC,
-                        AElementwiseOperation,
-                        BElementwiseOperation,
-                        CElementwiseOperation,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
                         ck::index_t KBatch = 1) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -311,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                           GridwiseGemm::CalculateNPadded(N),
                                           GridwiseGemm::CalculateKPadded(K, KBatch),
                                           GridwiseGemm::CalculateK0(K, KBatch),
-                                          KBatch);
+                                          KBatch,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
     }

     // polymorphic
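The net effect of these hunks is that the split-K device op now carries real element-wise functors instead of discarding them: the derived Argument stores a_element_op / b_element_op / c_element_op, MakeArgument and MakeArgumentPointer accept named parameters, and the functors are forwarded to kernel_gemm_xdlops_v2r4r2_simplified at launch. A sketch of the call-site implication (functor type names are placeholders):

// Before this change the three element-wise parameters were unnamed and ignored, so only
// pass-through behaviour was possible. After it, a call along the lines of
//     auto arg = device_op.MakeArgument(p_a, p_b, p_c, M, N, K,
//                                       StrideA, StrideB, StrideC,
//                                       AOp{}, BOp{}, COp{}, k_batch);
// stores the functors in the derived Argument and the kernel applies them, so a
// non-trivial CElementwiseOperation (e.g. a scaling functor) now actually takes effect.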
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp

@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop) {
             constexpr bool has_main_loop = has_main_k_block_loop.value;

-            const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
+            const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                 GridwiseGemm,
                 ADataType,
                 BDataType,
View file @
f30e5975
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...
@@ -22,32 +23,6 @@ namespace ck {
...
@@ -22,32 +23,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
// element-wise op
// element-wise op
OutElementwiseOperation
a_element_op_
;
OutElementwiseOperation
a_element_op_
;
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
,
has_main_loop
,
has_double_loop
>
;
has_double_loop
>
;
...
...
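DeviceGroupedConvBwdWeight_Dl now reuses the shared, templated ComputePtrOffsetOfStridedBatch from device_grouped_conv_utils.hpp (instantiated here as ComputePtrOffsetOfStridedBatch<I0>) instead of keeping a private copy in an anonymous namespace. Conceptually the helper still just turns a group index into per-operand pointer offsets; a simplified sketch of that idea (not the shared header itself):

#include "ck/utility/common_header.hpp" // for ck::index_t / ck::long_index_t

// Sketch: per-group pointer offset = group index * batch stride of that operand.
struct BatchPtrOffsetSketch
{
    ck::long_index_t GetAPtrOffset(ck::index_t g_idx) const
    {
        return g_idx * static_cast<ck::long_index_t>(BatchStrideA_);
    }
    // The shared utility provides the same for the B and C (or Ds) operands as well.
    ck::index_t BatchStrideA_;
};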
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
0 → 100644
View file @
f30e5975
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
K1
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
,
typename
ck
::
enable_if
<
NDimSpatial
==
3
,
bool
>
::
type
=
false
>
struct
DeviceGroupedConvBwdWeight_Wmma_CShuffle
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedConvBwdWeight_Wmma_CShuffle
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
CDataType
=
WeiDataType
;
using
AElementwiseOperation
=
OutElementwiseOperation
;
using
BElementwiseOperation
=
InElementwiseOperation
;
using
CElementwiseOperation
=
WeiElementwiseOperation
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
// 3d
static
constexpr
bool
is_NDHWGK_GKZYXC_NDHWGC
=
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NDHWGK
>
;
static
constexpr
bool
is_GNDHWK_GKZYXC_GNDHWC
=
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNDHWK
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
GemmK1Number
=
Number
<
K1
>
{};
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
GemmK1Number
;
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_out_grid_desc
(
const
index_t
N
,
const
index_t
Do
,
const
index_t
Ho
,
const
index_t
Wo
,
const
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_strides
)
{
const
index_t
WoStride
=
output_strides
[
5
];
const
auto
KStride
=
Number
<
1
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
),
make_tuple
(
WoStride
,
KStride
));
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_in_grid_desc
(
const
index_t
N
,
const
index_t
Di
,
const
index_t
Hi
,
const
index_t
Wi
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_strides
)
{
const
index_t
NStride
=
input_strides
[
1
];
const
index_t
DiStride
=
input_strides
[
3
];
const
index_t
HiStride
=
input_strides
[
4
];
const
index_t
WiStride
=
input_strides
[
5
];
const
auto
CStride
=
input_strides
[
2
];
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Di
*
Hi
*
Wi
,
C
),
make_tuple
(
WiStride
,
CStride
));
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
DiStride
,
HiStride
,
WiStride
,
CStride
));
}
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_wei_grid_desc
(
const
index_t
K
,
const
index_t
Z
,
const
index_t
Y
,
const
index_t
X
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weights_strides
)
{
const
auto
CStride
=
Number
<
1
>
{};
const
auto
KStride
=
weights_strides
[
1
];
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
),
make_tuple
(
KStride
,
CStride
));
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
const
index_t
N
,
const
index_t
K
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
using
namespace
ck
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
Z
=
filter_spatial_lengths
[
0
];
const
index_t
Y
=
filter_spatial_lengths
[
1
];
const
index_t
X
=
filter_spatial_lengths
[
2
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
GemmKTotal
=
N
*
Do
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
Z
*
X
*
Y
;
const
auto
PadGemmM
=
(
MPerBlock
-
GemmM
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadGemmN
=
(
NPerBlock
-
GemmN
%
NPerBlock
)
%
NPerBlock
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
const
auto
out_grid_desc
=
make_out_grid_desc
<
NDim
>
(
N
,
Do
,
Ho
,
Wo
,
K
,
output_strides
);
const
auto
in_grid_desc
=
make_in_grid_desc
<
NDim
>
(
N
,
Di
,
Hi
,
Wi
,
C
,
input_strides
);
const
auto
wei_grid_desc
=
make_wei_grid_desc
<
NDim
>
(
K
,
Z
,
Y
,
X
,
C
,
weights_strides
);
        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_grid_desc);
        }
        else
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_dip_hip_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

            const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_gemmktotal_gemmn_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
                transform_tensor_descriptor(
                    out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                    make_tuple(make_pass_through_transform(GemmK0),
                               make_right_pad_transform(GemmM, PadGemmM),
                               make_pass_through_transform(GemmK1Number)),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
                transform_tensor_descriptor(
                    in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                    make_tuple(make_pass_through_transform(GemmK0),
                               make_right_pad_transform(GemmN, PadGemmN),
                               make_pass_through_transform(GemmK1Number)),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(
                wei_grid_desc,
                make_tuple(make_right_pad_transform(GemmM, PadGemmM),
                           make_right_pad_transform(GemmN, PadGemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
                              wei_gemmm_gemmn_pad_grid_desc);
        }
    }
    template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto GetABCGridDesc()
    {
        const index_t dim = 1;
        const std::array<index_t, NDimSpatial> lengths{1, 1, 1};
        const std::array<index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
        const std::array<index_t, NDimSpatial> params{1, 1, 1};
        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
                                                                  dim,
                                                                  dim,
                                                                  lengths,
                                                                  lengths,
                                                                  lengths,
                                                                  strides,
                                                                  strides,
                                                                  strides,
                                                                  params,
                                                                  params,
                                                                  params,
                                                                  params);
    }
    using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
        // DataType Family
        ADataType,
        BDataType,
        AccDataType,
        CDataType,
        Tuple<>,
        CDataType,
        // InMemory Data Descriptor
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        Tuple<>,
        CGridDesc_M_N,
        // ElementwiseOp Family
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        K1,
        MRepeat,
        NRepeat,
        // ThreadCluster Family
        BlockSize,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        NumGemmKPrefetchStage,
        LoopSched,
        PipelineVer>;

    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{}));

    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            CGridDesc_M_N{}));

    using Block2CTileMap = decltype(
        GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));
    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_in_grid,
                 WeiDataType* p_wei_grid,
                 const OutDataType* p_out_grid,
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                 const std::array<index_t, NDimSpatial>& input_left_pads,
                 const std::array<index_t, NDimSpatial>& input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op,
                 index_t split_k)
            : p_a_grid_{p_out_grid},
              p_b_grid_{p_in_grid},
              p_c_grid_{p_wei_grid},
              a_grid_desc_kbatch_k0_m_k1_{},
              b_grid_desc_kbatch_k0_n_k1_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{},
              compute_ptr_offset_of_batch_{},
              a_element_op_{out_element_op},
              b_element_op_{in_element_op},
              c_element_op_{wei_element_op},
              Conv_G_{a_g_n_c_wis_lengths[0]},
              Conv_N_{a_g_n_c_wis_lengths[1]},
              Conv_K_{b_g_k_c_xs_lengths[1]},
              Conv_C_{a_g_n_c_wis_lengths[2]},
              input_spatial_lengths_{},
              filter_spatial_lengths_{},
              output_spatial_lengths_{},
              conv_filter_strides_{conv_filter_strides},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads},
              k_batch_{split_k}
        {
            constexpr index_t spatial_offset = 3;
            std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
                      end(a_g_n_c_wis_lengths),
                      begin(input_spatial_lengths_));
            std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
                      end(b_g_k_c_xs_lengths),
                      begin(filter_spatial_lengths_));
            std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
                      end(e_g_n_k_wos_lengths),
                      begin(output_spatial_lengths_));

            const auto descs =
                DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
                    Conv_N_,
                    Conv_K_,
                    Conv_C_,
                    input_spatial_lengths_,
                    filter_spatial_lengths_,
                    output_spatial_lengths_,
                    a_g_n_c_wis_strides,
                    b_g_k_c_xs_strides,
                    e_g_n_k_wos_strides,
                    conv_filter_strides,
                    conv_filter_dilations,
                    input_left_pads,
                    input_right_pads);

            a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
            b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
            c_grid_desc_m_n_            = descs[I2];

            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */);

            // A/B/C Batch Stride
            compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
            compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
            compute_ptr_offset_of_batch_.BatchStrideE_ =
                Conv_K_ * Conv_C_ *
                std::accumulate(begin(filter_spatial_lengths_),
                                end(filter_spatial_lengths_),
                                index_t{1},
                                std::multiplies<>{});

            if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
                                           b_grid_desc_kbatch_k0_n_k1_,
                                           c_grid_desc_m_n_,
                                           block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
        Block2CTileMap block_2_ctile_map_;

        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;

        OutElementwiseOperation a_element_op_;
        InElementwiseOperation b_element_op_;
        WeiElementwiseOperation c_element_op_;

        // for checking IsSupportedArgument()
        const index_t Conv_G_;
        const index_t Conv_N_;
        const index_t Conv_K_;
        const index_t Conv_C_;

        std::array<index_t, NDimSpatial> input_spatial_lengths_;
        std::array<index_t, NDimSpatial> filter_spatial_lengths_;
        std::array<index_t, NDimSpatial> output_spatial_lengths_;
        const std::array<index_t, NDimSpatial>& conv_filter_strides_;
        const std::array<index_t, NDimSpatial>& input_left_pads_;
        const std::array<index_t, NDimSpatial>& input_right_pads_;

        const index_t k_batch_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        void Print(const Argument& arg)
        {
            std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
        }

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                Print(arg);
            }

            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                            arg.b_grid_desc_kbatch_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
                    "setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;

            const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);

            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    typename GridwiseGemm::DsGridPointer,
                    CDataType,
                    OutElementwiseOperation,
                    InElementwiseOperation,
                    WeiElementwiseOperation,
                    AGridDesc_K0_M_K1,
                    BGridDesc_K0_N_K1,
                    DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    ComputePtrOffsetOfStridedBatch<I0>,
                    has_main_loop>;

                using EmptyTuple = Tuple<>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              EmptyTuple{}, // Ds
                                              arg.p_c_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.Conv_G_,
                                              arg.a_grid_desc_kbatch_k0_m_k1_,
                                              arg.b_grid_desc_kbatch_k0_n_k1_,
                                              DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{},
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_ctile_map_,
                                              arg.compute_ptr_offset_of_batch_);
            };

            if(has_main_k0_block_loop)
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        // check device
        if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
           get_device_name() == "gfx1102")
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // TODO: Add support for split_k > 1
        if(arg.k_batch_ != 1)
        {
            return false;
        }

        if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
        {
            return false;
        }

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's a 1x1 convolution with stride=1 and no padding
            for(int i = 0; i < NDimSpatial; i++)
            {
                if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
                     arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
                {
                    return false;
                }
            }
        }

        // vector load A/B matrix from global memory
        if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 &&
             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
        {
            return false;
        }

        // vector store C matrix into global memory
        if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            return false;
        }

        // Gridwise GEMM size
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                           arg.b_grid_desc_kbatch_k0_n_k1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto
    MakeArgument(const InDataType* p_in_grid,
                 WeiDataType* p_wei_grid,
                 const OutDataType* p_out_grid,
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                 const std::array<index_t, NDimSpatial>& input_left_pads,
                 const std::array<index_t, NDimSpatial>& input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op,
                 const index_t split_k)
    {
        return Argument{p_in_grid,
                        p_wei_grid,
                        p_out_grid,
                        a_g_n_c_wis_lengths, // input
                        a_g_n_c_wis_strides,
                        b_g_k_c_xs_lengths, // weight
                        b_g_k_c_xs_strides,
                        e_g_n_k_wos_lengths, // output
                        e_g_n_k_wos_strides,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op,
                        split_k};
    }

    static auto MakeInvoker() { return Invoker{}; }
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_grid,
                        void* p_wei_grid,
                        const void* p_out_grid,
                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op,
                        const index_t split_k) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                          static_cast<WeiDataType*>(p_wei_grid),
                                          static_cast<const OutDataType*>(p_out_grid),
                                          a_g_n_c_wis_lengths, // input
                                          a_g_n_c_wis_strides,
                                          b_g_k_c_xs_lengths, // weight
                                          b_g_k_c_xs_strides,
                                          e_g_n_k_wos_lengths, // output
                                          e_g_n_k_wos_strides,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          in_element_op,
                                          wei_element_op,
                                          out_element_op,
                                          split_k);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
            << K1 << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << ABlockTransferDstScalarPerVector_K1 << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << BBlockTransferDstScalarPerVector_K1 << ", "
            << CShuffleMRepeatPerShuffle << ", "
            << CShuffleNRepeatPerShuffle << ", "
            << CShuffleBlockTransferScalarPerVector_NPerBlock
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
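The new device op above follows the usual CK client pattern visible in this header: build an Argument through MakeArgumentPointer, validate it with IsSupportedArgument, then launch it via the Invoker. The following host-side sketch is not part of this commit; the wrapper name run_conv_bwd_weight_wmma, the template parameter DeviceOpT, and the assumption that the instance was built with PassThrough element-wise operations are illustrative only.

// Hypothetical usage sketch (not from this commit). DeviceOpT stands for any concrete
// DeviceGroupedConvBwdWeight_Wmma_CShuffle instantiation whose element-wise operations
// are PassThrough; lengths/strides arrays are supplied by the caller.
#include <array>
#include <memory>

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

template <typename DeviceOpT, ck::index_t NDimSpatial>
float run_conv_bwd_weight_wmma(const void* p_in,
                               void* p_wei,
                               const void* p_out,
                               const std::array<ck::index_t, NDimSpatial + 3>& in_lengths,
                               const std::array<ck::index_t, NDimSpatial + 3>& in_strides,
                               const std::array<ck::index_t, NDimSpatial + 3>& wei_lengths,
                               const std::array<ck::index_t, NDimSpatial + 3>& wei_strides,
                               const std::array<ck::index_t, NDimSpatial + 3>& out_lengths,
                               const std::array<ck::index_t, NDimSpatial + 3>& out_strides,
                               const std::array<ck::index_t, NDimSpatial>& conv_strides,
                               const std::array<ck::index_t, NDimSpatial>& conv_dilations,
                               const std::array<ck::index_t, NDimSpatial>& left_pads,
                               const std::array<ck::index_t, NDimSpatial>& right_pads)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    DeviceOpT device_op;

    // split_k must be 1: IsSupportedArgument() above rejects any other value for now.
    auto argument = device_op.MakeArgumentPointer(p_in,
                                                  p_wei,
                                                  p_out,
                                                  in_lengths,
                                                  in_strides,
                                                  wei_lengths,
                                                  wei_strides,
                                                  out_lengths,
                                                  out_strides,
                                                  conv_strides,
                                                  conv_dilations,
                                                  left_pads,
                                                  right_pads,
                                                  PassThrough{},
                                                  PassThrough{},
                                                  PassThrough{},
                                                  1 /* split_k */);

    // The op only runs on gfx1100/1101/1102 with fp32/int32 accumulation; bail out otherwise.
    if(!device_op.IsSupportedArgument(argument.get()))
    {
        return -1.f;
    }

    auto invoker = device_op.MakeInvokerPointer();
    return invoker->Run(argument.get()); // returns kernel time reported by launch_and_time_kernel
}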
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp

@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -21,32 +22,6 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-namespace {
-
-struct ComputePtrOffsetOfStridedBatch
-{
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideC_);
-    }
-
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
-    index_t BatchStrideC_;
-};
-
-} // namespace
-
 template <typename GridwiseGemm,
           typename FloatA,
           typename FloatB,

@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         Block2CTileMap block_2_ctile_map_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;

         index_t M01_;
         index_t N01_;

@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                     remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                     remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                     remove_reference_t<DeviceOp::Block2CTileMap>,
-                    ComputePtrOffsetOfStridedBatch,
+                    ComputePtrOffsetOfStridedBatch<I0>,
                     has_main_loop>;

                 return launch_and_time_kernel(stream_config,

@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
         if constexpr(NDimSpatial == 1)
         {
             if constexpr(!is_GNWK_GKXC_GNWC)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop) {
             constexpr bool has_main_loop = has_main_k_block_loop.value;

-            const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
+            const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                 GridwiseOp,
                 ADataType,
                 BDataType,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp

@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
         return ds_offset;
     }

-    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+    }
+
+    // alias for kernels without multiple D
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
     {
         return g_idx * static_cast<long_index_t>(BatchStrideE_);
     }

@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
     index_t BatchStrideB_;
     Array<ck::index_t, NumDTensor> BatchStrideDs_;
     index_t BatchStrideE_;
+    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp

@@ -28,6 +28,7 @@ template <typename XDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,

@@ -43,12 +44,13 @@ template <typename XDataType,
           index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize,
           bool UseWelford = true>
 struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
-                                                            ComputeDataType,
                                                             YDataType,
+                                                            SaveMeanInvStdDataType,
                                                             YElementwiseOperation,
                                                             Rank,
                                                             NumReduceDim>

@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
         "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
+
     using PassThrough = tensor_operation::element_wise::PassThrough;

+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+    static_assert(!reduceAllDim); // TODO
+
     static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                     const std::vector<index_t>& inStrides,
                                     int numBlockTileIteration)
     {
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         static constexpr index_t numSrcDim = Rank;
-        static constexpr bool reduceAllDim = (NumInvariantDim == 0);

         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         return (in_grid_desc_m_k_padded);
     };

+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
+                                               const std::vector<index_t>& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded =
+            transform_tensor_descriptor(grid_desc_m,
+                                        make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+
+        return grid_desc_m_padded;
+    }
+
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
+    using GridDesc_M   = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));

     struct Argument : public BaseArgument
     {

@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                  const std::vector<index_t> gammaStrides,
                  const std::vector<index_t> betaStrides,
                  const std::vector<index_t> yStrides,
+                 const std::vector<index_t> saveMeanStrides,
+                 const std::vector<index_t> saveInvStdStrides,
                  const std::vector<index_t> reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
+              p_saveMean_(p_saveMean),
+              p_saveInvStd_(p_saveInvStd),
               y_elementwise_op_(y_elementwise_op)
         {
             epsilon_ = static_cast<ComputeDataType>(epsilon);

@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
             betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
+            saveMeanStrides_   = saveMeanStrides;
+            saveInvStdStrides_ = saveInvStdStrides;

-            long_index_t invariant_length;
-            long_index_t reduce_length;
-
-            std::tie(invariant_length, reduce_length) =
-                get_2d_lengths<Rank, NumReduceDim>(Lengths_);
-            numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
-            gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
+            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
+            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
+            gridSize_              = math::integer_divide_ceil(MRaw_, M_BlockTileSize);

             x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
             gamma_grid_desc_m_k_ =

@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             beta_grid_desc_m_k_ =
                 MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
             y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
+            save_mean_grid_desc_m_    = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
+            save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);

             isSweeponce_ =
                 x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
         }

         ComputeDataType epsilon_;

@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
         YDataType* p_y_;
+        SaveMeanInvStdDataType* p_saveMean_;
+        SaveMeanInvStdDataType* p_saveInvStd_;

         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
+        std::vector<index_t> saveMeanStrides_;
+        std::vector<index_t> saveInvStdStrides_;

         YElementwiseOperation y_elementwise_op_;

@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         GridDesc_M_K gamma_grid_desc_m_k_;
         GridDesc_M_K beta_grid_desc_m_k_;
         GridDesc_M_K y_grid_desc_m_k_;
+        GridDesc_M save_mean_grid_desc_m_;
+        GridDesc_M save_inv_std_grid_desc_m_;

         bool isSweeponce_;
+
+        index_t MRaw_; // invarient length
+        index_t KRaw_; // reduce length
+
+        index_t invariant_lowest_length_;
     };

     struct Invoker : public BaseInvoker

@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                               GammaDataType,
                                                               BetaDataType,
                                                               YDataType,
+                                                              SaveMeanInvStdDataType,
                                                               ComputeDataType,
                                                               YElementwiseOperation,
                                                               GridDesc_M_K,
+                                                              GridDesc_M,
                                                               BlockSize,
                                                               MThreadClusterSize,
                                                               KThreadClusterSize,

@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                               BetaSrcVectorSize,
                                                               XYSrcVectorDim,
                                                               YDstVectorSize,
+                                                              SaveMeanInvStdDstVectorSize,
                                                               UseWelford>(arg.isSweeponce_);

             float avg_time = 0;

@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                               arg.gamma_grid_desc_m_k_,
                                               arg.beta_grid_desc_m_k_,
                                               arg.y_grid_desc_m_k_,
+                                              arg.save_mean_grid_desc_m_,
+                                              arg.save_inv_std_grid_desc_m_,
                                               arg.numBlockTileIteration_,
                                               arg.epsilon_,
                                               arg.p_x_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               arg.p_y_,
+                                              arg.p_saveMean_,
+                                              arg.p_saveInvStd_,
                                               arg.y_elementwise_op_);

             return (avg_time);

@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         if constexpr(XYSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)

@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             }
             else
             {
+                printf("!!!! %d \n", p_arg_->invariant_lowest_length_);
+
                 if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                     return false;

-                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                     return false;

-                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
                     return false;
             };
         }

@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                     return (false);

-                if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                     return (false);
             }
             else // if fastest dim is reduced

@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 return (false);
         }

+        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
+            return false;
+
         return true;
     };

@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                         const std::vector<index_t> gammaStrides,
                         const std::vector<index_t> betaStrides,
                         const std::vector<index_t> yStrides,
+                        const std::vector<index_t> saveMeanStrides,
+                        const std::vector<index_t> saveInvStdStrides,
                         const std::vector<index_t> reduceDims,
                         double epsilon,
                         const void* p_x,

@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                         const void* p_beta,
                         void* p_y,
                         void* p_saveMean,
-                        void* p_saveInvVar,
+                        void* p_saveInvStd,
                         YElementwiseOperation y_elementwise_op) override
     {
-        // Optional cache of the intermediate results (mean and InvVariance) during the
-        // forward pass could speedup in the backward
-        ignore = p_saveMean;
-        ignore = p_saveInvVar;
+        // TODO
+        if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
+           betaStrides.size() != Rank || yStrides.size() != Rank ||
+           saveMeanStrides.size() != NumInvariantDim ||
+           saveInvStdStrides.size() != NumInvariantDim)
+            throw std::runtime_error("dimension is incorrect");

         return std::make_unique<Argument>(lengths,
                                           xStrides,
                                           gammaStrides,
                                           betaStrides,
                                           yStrides,
+                                          saveMeanStrides,
+                                          saveInvStdStrides,
                                           reduceDims,
                                           y_elementwise_op,
                                           epsilon,
                                           static_cast<const XDataType*>(p_x),
                                           static_cast<const GammaDataType*>(p_gamma),
                                           static_cast<const BetaDataType*>(p_beta),
-                                          static_cast<YDataType*>(p_y));
+                                          static_cast<YDataType*>(p_y),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveMean),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
     };

     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
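The diff above extends DeviceNormalizationImpl so the forward pass can also write out the per-row mean and inverse standard deviation (p_saveMean / p_saveInvStd), which a backward pass can reuse instead of recomputing them. The sketch below shows how the extended MakeArgumentPointer interface would be driven from host code; it is not part of this commit, and the helper name launch_layernorm_2d, the template parameter DeviceLayernormT, and the 2D (M x K, reduce last dim) layout are illustrative assumptions.

// Hypothetical layernorm call sketch (not from this commit). DeviceLayernormT stands for a
// Rank == 2, NumReduceDim == 1 DeviceNormalizationImpl instance with a PassThrough Y op.
#include <vector>

#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

template <typename DeviceLayernormT>
bool launch_layernorm_2d(const void* p_x,
                         const void* p_gamma,
                         const void* p_beta,
                         void* p_y,
                         void* p_save_mean,
                         void* p_save_inv_std,
                         ck::index_t M,
                         ck::index_t K,
                         double epsilon)
{
    using Pass = ck::tensor_operation::element_wise::PassThrough;

    DeviceLayernormT device_op;

    // Reduce the last dimension; gamma/beta are broadcast over M, and the saved
    // mean / inv-std are length-M vectors (one stride per invariant dimension).
    auto argument = device_op.MakeArgumentPointer({M, K},  // lengths
                                                  {K, 1},  // xStrides
                                                  {0, 1},  // gammaStrides
                                                  {0, 1},  // betaStrides
                                                  {K, 1},  // yStrides
                                                  {1},     // saveMeanStrides
                                                  {1},     // saveInvStdStrides
                                                  {1},     // reduceDims
                                                  epsilon,
                                                  p_x,
                                                  p_gamma,
                                                  p_beta,
                                                  p_y,
                                                  p_save_mean,
                                                  p_save_inv_std,
                                                  Pass{});

    if(!device_op.IsSupportedArgument(argument.get()))
        return false;

    device_op.MakeInvokerPointer()->Run(argument.get());
    return true;
}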
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp

@@ -19,7 +19,7 @@
 namespace ck {

 template <typename GridwiseWelford,
           typename XDataType,
-          typename MeanVarDataType,
+          typename WorkspaceMeanVarDataType,
           typename ComputeDataType,
           typename XGridDesc_M_K,
           typename MeanVarGridDesc_M_KBlock>

@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
                               const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x_global,
-                              MeanVarDataType* const __restrict__ p_welford_mean,
-                              MeanVarDataType* const __restrict__ p_welford_variance,
+                              WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
+                              WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
                               int32_t* const __restrict__ p_welford_count)
 {
     GridwiseWelford::Run(x_grid_desc_m_k,

@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
 };

 template <typename GridwiseWelfordNormalization,
-          typename MeanVarDataType,
+          typename WorkspaceMeanVarDataType,
           typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename ComputeDataType,
           typename YElementwiseOperation,
           typename MeanVarGridDesc_M_KBlock,
           typename CountGridDesc_M_KBlock,
-          typename XYGammaBetaGridDesc_M_K>
+          typename XYGammaBetaGridDesc_M_K,
+          typename SaveMeanInvStdGridDesc_M>
 __global__ void
 kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                               const CountGridDesc_M_KBlock count_grid_desc_m_kblock,

@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                               const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
+                              const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
+                              const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
                               ComputeDataType epsilon,
-                              const MeanVarDataType* const p_mean_global,
-                              const MeanVarDataType* const p_variance_global,
+                              const WorkspaceMeanVarDataType* const p_mean_global,
+                              const WorkspaceMeanVarDataType* const p_variance_global,
                               const int32_t* const p_welford_count_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
 {
     GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,

@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                                       gamma_grid_desc_m_k,
                                       beta_grid_desc_m_k,
                                       y_grid_desc_m_k,
+                                      save_mean_grid_desc_m,
+                                      save_inv_std_grid_desc_m,
                                       num_k_mean_var_count_iteration,
                                       num_k_block_tile_iteration,
                                       k_grid_size,

@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                                       p_gamma_global,
                                       p_beta_global,
                                       p_y_global,
+                                      p_save_mean_global,
+                                      p_save_inv_std_global,
                                       y_elementwise_op);
 };

 } // namespace ck

@@ -107,6 +117,7 @@ template <typename XDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,

@@ -121,17 +132,18 @@ template <typename XDataType,
           index_t GammaSrcVectorSize,
           index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
-          index_t YDstVectorSize>
+          index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize>
 struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
                                                                   GammaDataType,
                                                                   BetaDataType,
-                                                                  ComputeDataType,
                                                                   YDataType,
+                                                                  SaveMeanInvStdDataType,
                                                                   YElementwiseOperation,
                                                                   Rank,
                                                                   NumReduceDim>
 {
-    using MeanVarDataType = ComputeDataType;
+    using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;

     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);

     static_assert(

@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
             (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
         "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
+
     using PassThrough = tensor_operation::element_wise::PassThrough;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+    static_assert(!reduceAllDim); // TODO
+
     static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                     const std::vector<index_t>& inStrides,
                                     int kBlockSize,
                                     int numBlockTileIteration)
     {
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         static constexpr index_t numSrcDim = Rank;
-        static constexpr bool reduceAllDim = (NumInvariantDim == 0);

         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
     };

     template <typename DoPads, index_t MPerTile, index_t KPerTile>
-    static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K)
+    static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
     {
         const auto grid_desc_m_k =
             make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));

@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
     }

     template <typename DoPads, index_t MPerTile, index_t KPerTile>
-    static auto MakeCountDescriptor_M_K(index_t M, index_t K)
+    static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
     {
         const auto grid_desc_m_k =
             make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
         return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
     }

+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
+                                               const std::vector<index_t>& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded =
+            transform_tensor_descriptor(grid_desc_m,
+                                        make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+
+        return grid_desc_m_padded;
+    }
+
     using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));

     using Kernel1MeanVarGridDesc_M_KBlock =
-        decltype(MakeMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
+        decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));

     using Kernel2MeanVarGridDesc_M_KBlock =
-        decltype(MakeMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
+        decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));

     using Kernel2CountGridDesc_M_KBlock =
-        decltype(MakeCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
+        decltype(MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));

+    using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
+
     using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType,
                                                            ComputeDataType,
-                                                           MeanVarDataType,
+                                                           WorkspaceMeanVarDataType,
                                                            SrcGridDesc_M_K,
                                                            Kernel1MeanVarGridDesc_M_KBlock,
                                                            BlockSize,

@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
                                                            XSrcVectorSize>;

     using GridwiseWelfordNormalization =
-        GridwiseNormalizationSplitK2nd<MeanVarDataType,
+        GridwiseNormalizationSplitK2nd<WorkspaceMeanVarDataType,
                                        XDataType,
                                        GammaDataType,
                                        BetaDataType,
                                        YDataType,
+                                       SaveMeanInvStdDataType,
                                        ComputeDataType,
                                        YElementwiseOperation,
                                        Kernel2MeanVarGridDesc_M_KBlock,
                                        Kernel2CountGridDesc_M_KBlock,
                                        SrcGridDesc_M_K,
+                                       SaveMeanInvStdGridDesc_M,
                                        BlockSize,
                                        MThreadClusterSize,
                                        KThreadClusterSize,

@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
                                        BetaSrcVectorDim,
                                        BetaSrcVectorSize,
                                        XYVectorDim,
-                                       YDstVectorSize>;
+                                       YDstVectorSize,
+                                       SaveMeanInvStdDstVectorSize>;

     struct Argument : public BaseArgument
     {

@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
                  const std::vector<index_t> gammaStrides,
                  const std::vector<index_t> betaStrides,
                  const std::vector<index_t> yStrides,
+                 const std::vector<index_t> saveMeanStrides,
+                 const std::vector<index_t> saveInvStdStrides,
                  const std::vector<index_t> reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
+              p_saveMean_(p_saveMean),
+              p_saveInvStd_(p_saveInvStd),
               p_workspace_mean_{nullptr},
               p_workspace_var_{nullptr},
               p_workspace_count_{nullptr},

@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
saveMeanStrides_
=
saveMeanStrides
;
saveInvStdStrides_
=
saveInvStdStrides
;
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
...
@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
y_grid_desc_m_k_
=
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
save_mean_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveMeanStrides
);
save_inv_std_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveInvStdStrides
);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_
=
kernel1_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
kGridSize_
);
MRaw_
,
kGridSize_
);
kernel2_mean_var_grid_desc_m_kblock_
=
kernel2_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
kernel2_count_grid_desc_m_kblock_
=
kernel2_count_grid_desc_m_kblock_
=
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
MakeWorkspaceCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
if
constexpr
(
NumInvariantDim
==
0
)
invariant_lowest_length_
=
1
;
else
invariant_lowest_length_
=
Lengths_
[
NumInvariantDim
-
1
];
}
}
ComputeDataType
epsilon_
;
ComputeDataType
epsilon_
;
...
@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
GammaDataType
*
p_gamma_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
const
BetaDataType
*
p_beta_
;
YDataType
*
p_y_
;
YDataType
*
p_y_
;
SaveMeanInvStdDataType
*
p_saveMean_
;
SaveMeanInvStdDataType
*
p_saveInvStd_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_var_
;
void
*
p_workspace_var_
;
void
*
p_workspace_count_
;
void
*
p_workspace_count_
;
...
@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
saveMeanStrides_
;
std
::
vector
<
index_t
>
saveInvStdStrides_
;
YElementwiseOperation
y_elementwise_op_
;
YElementwiseOperation
y_elementwise_op_
;
...
@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K
gamma_grid_desc_m_k_
;
SrcGridDesc_M_K
gamma_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
SaveMeanInvStdGridDesc_M
save_mean_grid_desc_m_
;
SaveMeanInvStdGridDesc_M
save_inv_std_grid_desc_m_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
...
@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
index_t
MRaw_
;
// invarient length
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
...
@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
XDataType
,
XDataType
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
ComputeDataType
,
ComputeDataType
,
SrcGridDesc_M_K
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
>
;
Kernel1MeanVarGridDesc_M_KBlock
>
;
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
XDataType
,
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
>
;
SrcGridDesc_M_K
,
SaveMeanInvStdGridDesc_M
>
;
float
avg_time
=
0
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
kernel1
,
stream_config
,
dim3
(
arg
.
gridSize_
),
kernel1
,
dim3
(
BlockSize
),
dim3
(
arg
.
gridSize_
),
0
,
dim3
(
BlockSize
),
arg
.
x_grid_desc_m_k_
,
0
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
numBlockTileIteration_
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
p_x_
,
arg
.
numBlockTileIteration_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
static_cast
<
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
avg_time
+=
launch_and_time_kernel
(
dim3
(
arg
.
gridSize_
),
stream_config
,
dim3
(
BlockSize
),
kernel2
,
0
,
dim3
(
arg
.
gridSize_
),
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
dim3
(
BlockSize
),
arg
.
kernel2_count_grid_desc_m_kblock_
,
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
numMeanVarCountIteration_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
numBlockTileIteration_
,
arg
.
y_grid_desc_m_k_
,
arg
.
kGridSize_
,
arg
.
save_mean_grid_desc_m_
,
arg
.
epsilon_
,
arg
.
save_inv_std_grid_desc_m_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
arg
.
numMeanVarCountIteration_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
arg
.
numBlockTileIteration_
,
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
kGridSize_
,
arg
.
p_x_
,
arg
.
epsilon_
,
arg
.
p_gamma_
,
static_cast
<
const
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
arg
.
p_beta_
,
static_cast
<
const
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_var_
),
arg
.
p_y_
,
static_cast
<
const
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
y_elementwise_op_
);
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
p_saveMean_
,
arg
.
p_saveInvStd_
,
arg
.
y_elementwise_op_
);
return
avg_time
;
return
avg_time
;
};
};
...
@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
// workspace for welford intermediate mean
// workspace for welford intermediate mean
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
workspace_size
+=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
// workspace for welford intermediate variance
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
workspace_size
+=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
kGridSize_
*
sizeof
(
int32_t
)
+
64
;
workspace_size
+=
pArg_
->
kGridSize_
*
sizeof
(
int32_t
)
+
64
;
...
@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
// setup buffer used for intermediate welford mean
// setup buffer used for intermediate welford mean
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
// setup buffer used for intermediate welford varirance
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
index_t
variance_space_sz
=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welford count
// setup buffer used for intermediate welford count
...
@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
{
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XYVectorDim
==
0
)
if
constexpr
(
XYVectorDim
==
0
)
{
{
if
constexpr
(
NumInvariantDim
==
0
)
if
constexpr
(
NumInvariantDim
==
0
)
...
@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
XSrcVectorSize
!=
0
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
YDstVectorSize
!=
0
)
return
false
;
return
false
;
};
};
}
}
...
@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
BetaSrcVectorSize
!=
0
)
return
false
;
return
false
;
}
}
else
// if fastest dim is reduced
else
// if fastest dim is reduced
...
@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
kGridSize_
<=
1
)
if
(
p_arg_
->
kGridSize_
<=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length_
%
SaveMeanInvStdDstVectorSize
!=
0
)
return
false
;
return
true
;
return
true
;
};
};
...
@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
std
::
vector
<
index_t
>
reduceDims
,
double
epsilon
,
double
epsilon
,
const
void
*
p_x
,
const
void
*
p_x
,
...
@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
void
*
p_beta
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveMean
,
void
*
p_saveInv
Var
,
void
*
p_saveInv
Std
,
YElementwiseOperation
y_elementwise_op
)
override
YElementwiseOperation
y_elementwise_op
)
override
{
{
// TODO
if
(
lengths
.
size
()
!=
Rank
||
xStrides
.
size
()
!=
Rank
||
gammaStrides
.
size
()
!=
Rank
||
// Optional cache of the intermediate results (mean and InvVariance) during the
betaStrides
.
size
()
!=
Rank
||
yStrides
.
size
()
!=
Rank
||
// forward pass could speedup in the backward
saveMeanStrides
.
size
()
!=
NumInvariantDim
||
saveInvStdStrides
.
size
()
!=
NumInvariantDim
)
ignore
=
p_saveMean
;
throw
std
::
runtime_error
(
"dimension is incorrect"
);
ignore
=
p_saveInvVar
;
return
std
::
make_unique
<
Argument
>
(
lengths
,
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
xStrides
,
gammaStrides
,
gammaStrides
,
betaStrides
,
betaStrides
,
yStrides
,
yStrides
,
saveMeanStrides
,
saveInvStdStrides
,
reduceDims
,
reduceDims
,
y_elementwise_op
,
y_elementwise_op
,
epsilon
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
YDataType
*>
(
p_y
));
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveMean
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveInvStd
));
};
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
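For orientation, the split-K workspace that the hunks above size and partition consists of three slabs laid end to end: intermediate mean, intermediate variance (both of WorkspaceMeanVarDataType) and an int32 count buffer, with each slab offset aligned to a 64-byte boundary via integer_least_multiple. The sketch below is not CK code; it only mirrors that arithmetic on the host, assuming WorkspaceMeanVarDataType is float and using made-up MRaw_/kGridSize_ values.

    // Host-side illustration of the workspace layout used by the split-K normalization:
    // mean slab, variance slab, count slab, each starting on a 64-byte boundary.
    #include <cstdint>
    #include <cstdio>

    std::size_t round_up(std::size_t x, std::size_t align) { return (x + align - 1) / align * align; }

    int main()
    {
        const std::size_t MRaw = 1024, kGridSize = 8;       // illustrative problem split
        const std::size_t welford_size = MRaw * kGridSize;  // mean/var entries across K partitions

        const std::size_t mean_sz  = round_up(welford_size * sizeof(float), 64);        // WorkspaceMeanVarDataType assumed float
        const std::size_t var_sz   = round_up(welford_size * sizeof(float), 64);
        const std::size_t count_sz = round_up(kGridSize * sizeof(std::int32_t), 64);

        std::printf("mean @ 0, var @ %zu, count @ %zu, total %zu bytes\n",
                    mean_sz, mean_sz + var_sz, mean_sz + var_sz + count_sz);
        return 0;
    }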
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (View file @ f30e5975)
...
@@ -113,7 +113,6 @@ struct PassThrough
    }
#endif

-#if defined CK_ENABLE_FP8
    template <>
    __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
    {
...
@@ -143,9 +142,7 @@ struct PassThrough
    {
        y = type_convert<f8_t>(x);
    }
-#endif

-#if defined CK_ENABLE_BF8
    template <>
    __host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
    {
...
@@ -175,7 +172,6 @@ struct PassThrough
    {
        y = ck::type_convert<bf8_t>(x);
    }
-#endif
};

struct UnaryConvert
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
    }
};

-#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
    // convert to fp8 using stochastic rounding (SR)
...
@@ -212,7 +207,8 @@ struct ConvertF8SR
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // check Y datatype
-       static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
+       static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
+                     "Data type is not supported by this operation!");

        // check X datatype
        static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
...
@@ -221,7 +217,6 @@ struct ConvertF8SR
        y = f8_convert_sr<Y>(x);
    }
};
-#endif

struct Scale
{
...
@@ -448,10 +443,11 @@ struct Sigmoid
    __host__ __device__ void operator()(T& y, const T& x) const
    {
-       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                         is_same<T, ck::half_t>::value,
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                         is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");
-       y = 1 / (ck::type_convert<T>(1) + exp(-x));
+       constexpr T one = type_convert<T>(1);
+       y               = one / (one + ck::math::exp(-x));
    };
};
...
@@ -461,7 +457,8 @@ struct TanH
    __host__ __device__ void operator()(T& y, const T& x) const
    {
-       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                         is_same<T, ck::half_t>::value,
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                         is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::tanh(x);
...
@@ -487,7 +484,101 @@ struct Swish
        y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
    };

-   float beta_ = 1.0f;
+   const float beta_;
};

+struct SoftRelu
+{
+   SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+
+   template <typename T>
+   __host__ __device__ void operator()(T& y, const T& x) const
+   {
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                         is_same<T, int8_t>::value,
+                     "Data type is not supported by this operation!");
+       T casted_alpha  = type_convert<T>(alpha_);
+       constexpr T one = type_convert<T>(1);
+       y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+   }
+   const float alpha_;
+};
+
+struct Power
+{
+   Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+       : alpha_(alpha), beta_(beta), gamma_(gamma){};
+
+   template <typename T>
+   __host__ __device__ void operator()(T& y, const T& x) const
+   {
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                         is_same<T, int8_t>::value,
+                     "Data type is not supported by this operation!");
+       T casted_alpha     = type_convert<T>(alpha_);
+       T casted_beta      = type_convert<T>(beta_);
+       T casted_gamma     = type_convert<T>(gamma_);
+       T shifted_scaled_x = casted_alpha + casted_beta * x;
+       y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
+   }
+   const float alpha_;
+   const float beta_;
+   const float gamma_;
+};
+
+struct ClippedRelu
+{
+   ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+
+   template <typename T>
+   __host__ __device__ void operator()(T& y, const T& x) const
+   {
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                         is_same<T, int8_t>::value,
+                     "Data type is not supported by this operation!");
+       T casted_alpha = type_convert<T>(alpha_);
+       T casted_beta  = type_convert<T>(beta_);
+       y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+   }
+   const float alpha_;
+   const float beta_;
+};
+
+struct LeakyRelu
+{
+   LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
+
+   template <typename T>
+   __host__ __device__ void operator()(T& y, const T& x) const
+   {
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                         is_same<T, int8_t>::value,
+                     "Data type is not supported by this operation!");
+       T casted_alpha = type_convert<T>(alpha_);
+       y              = x >= 0 ? x : x * casted_alpha;
+   }
+   const float alpha_;
+};
+
+struct Elu
+{
+   Elu(float alpha = 1.f) : alpha_(alpha){};
+
+   template <typename T>
+   __host__ __device__ void operator()(T& y, const T& x) const
+   {
+       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                         is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                         is_same<T, int8_t>::value,
+                     "Data type is not supported by this operation!");
+       T casted_alpha = type_convert<T>(alpha_);
+       y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+   }
+   const float alpha_;
+};

} // namespace element_wise
...
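The newly added functors are plain host/device callables, so their formulas are easy to sanity-check on the CPU. The snippet below is a standalone restatement of three of them (SoftRelu, LeakyRelu, Elu) using <cmath>; it is illustrative only and does not use the CK types above.

    // Standalone check of the formulas implemented by the new elementwise functors.
    #include <cmath>
    #include <cstdio>

    float soft_relu(float x, float alpha = 1.f)  { return std::log(1.f + std::exp(x * alpha)) / alpha; }
    float leaky_relu(float x, float alpha = 0.01f) { return x >= 0.f ? x : x * alpha; }
    float elu(float x, float alpha = 1.f)        { return x > 0.f ? x : alpha * std::expm1(x); }

    int main()
    {
        std::printf("soft_relu(-2)  = %f\n", soft_relu(-2.f));  // ~0.1269
        std::printf("leaky_relu(-2) = %f\n", leaky_relu(-2.f)); // -0.02
        std::printf("elu(-2)        = %f\n", elu(-2.f));        // ~-0.8647
        return 0;
    }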
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp  (View file @ f30e5975)
...
@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
            [&](auto i) {
                using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;

-               return MakeAGridDescriptor_M_N<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
+               return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
            },
            Number<NumATensor>{});
    }
...
@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
-           ComputeDataType,
+           ComputeDataType, // ComputeDataType for A
+           ComputeDataType, // ComputeDataType for B
            AccDataType,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp  (View file @ f30e5975)
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-       kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle(
+       kernel_grouped_conv_multiple_d_wmma_cshuffle(
            const ADataType* __restrict__ p_a_grid,
            const BDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+   // CheckValidity for kernels without multi D
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
-                 const DsGridDesc_M_N& ds_grid_desc_m_n,
                  const EGridDesc_M_N& e_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
    {
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

-       bool valid = true;
-       static_for<0, NumDTensor, 1>{}([&](auto i) {
-           valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
-                             N == ds_grid_desc_m_n[i].GetLength(I1));
-       });
-
-       if(!valid)
-       {
-           return false;
-       }
-
        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
        return true;
    }

+   template <typename Block2CTileMap>
+   __host__ __device__ static constexpr bool
+   CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                 const DsGridDesc_M_N& ds_grid_desc_m_n,
+                 const EGridDesc_M_N& e_grid_desc_m_n,
+                 const Block2CTileMap& block_2_ctile_map)
+   {
+       const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
+       const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
+
+       bool valid = true;
+       static_for<0, NumDTensor, 1>{}([&](auto i) {
+           valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
+                             N == ds_grid_desc_m_n[i].GetLength(I1));
+       });
+
+       if(!valid)
+       {
+           return false;
+       }
+
+       return CheckValidity(
+           a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
+   }
+
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / (K0PerBlock * K1);
...
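The restructuring above splits validity checking in two: a new overload first verifies that every optional D tensor matches the E tensor's M x N extents and then forwards to the D-free CheckValidity. A toy, self-contained sketch of that delegation pattern follows (not CK code; Shape and the plain loop stand in for the descriptor types and static_for):

    // Validate the optional D tensors first, then reuse the common checks.
    #include <cstdio>
    #include <vector>

    struct Shape { int m, n; };

    bool check_validity(Shape e) { return e.m > 0 && e.n > 0; } // stands in for the D-free overload

    bool check_validity(const std::vector<Shape>& ds, Shape e)
    {
        for(const Shape& d : ds)          // mirrors the static_for over NumDTensor
            if(d.m != e.m || d.n != e.n)
                return false;             // any mismatching D tensor invalidates the launch
        return check_validity(e);         // then delegate to the shared checks
    }

    int main()
    {
        std::printf("%d\n", check_validity({{128, 256}, {128, 256}}, {128, 256})); // 1
        std::printf("%d\n", check_validity({{128, 128}}, {128, 256}));             // 0
        return 0;
    }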
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp  (View file @ f30e5975)
...
@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
        }
    }();

-   if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+   if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
    {
        return transform_tensor_descriptor(c_grid_desc_m_n,
                                           make_tuple(make_right_pad_transform(M, MPad - M),
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  (View file @ f30e5975)
...
@@ -22,13 +22,19 @@ namespace ck {
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-         typename Block2CTileMap>
+         typename Block2CTileMap,
+         typename AElementwiseOperation,
+         typename BElementwiseOperation,
+         typename CElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
-                                        const Block2CTileMap& b2c_map)
+                                        const Block2CTileMap& b2c_map,
+                                        const AElementwiseOperation a_element_op,
+                                        const BElementwiseOperation b_element_op,
+                                        const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -37,10 +43,13 @@ __global__ void
    __shared__ uint8_t p_shared[shared_size];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-       karg, static_cast<void*>(p_shared), b2c_map);
+       karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
#else
    ignore = karg;
    ignore = b2c_map;
+   ignore = a_element_op;
+   ignore = b_element_op;
+   ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
              typename Block2CTileMap>
    __device__ static void Run(const Argument& karg,
                               void* __restrict__ p_shared_block,
-                              const Block2CTileMap& block_2_ctile_map)
+                              const Block2CTileMap& block_2_ctile_map,
+                              const AElementwiseOperation a_element_op = AElementwiseOperation{},
+                              const BElementwiseOperation b_element_op = BElementwiseOperation{},
+                              const CElementwiseOperation c_element_op = CElementwiseOperation{})
    {
        const FloatA* p_a_grid = karg.p_a_grid;
        const FloatB* p_b_grid = karg.p_b_grid;
...
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);

-       const AElementwiseOperation a_element_op = AElementwiseOperation{};
-       const BElementwiseOperation b_element_op = BElementwiseOperation{};
-       const CElementwiseOperation c_element_op = CElementwiseOperation{};
-
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
...
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
-           ComputeType,
-           ComputeType,
+           ComputeType, // ComputeType A
+           ComputeType, // ComputeType B
            FloatAcc,
            decltype(a_k0_m_k1_block_desc),
            decltype(b_k0_n_k1_block_desc),
...
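The v2r4r2 change above moves the A/B/C element-wise operations from being default-constructed inside Run to being constructed by the caller and passed through the kernel entry point, so stateful operations (for example a runtime scale factor) can reach the GEMM. A minimal standalone illustration of that calling pattern follows; it is not CK code, and the Scale functor and the host loop standing in for the kernel are made up for the example.

    // Functors are now chosen and constructed by the caller, then passed by value.
    #include <cstdio>

    struct Scale
    {
        float s;
        void operator()(float& y, float x) const { y = s * x; }
    };

    template <typename AOp>
    void run(const float* a, float* c, int n, AOp a_op) // a_op plays the role of a_element_op
    {
        for(int i = 0; i < n; ++i)
            a_op(c[i], a[i]);
    }

    int main()
    {
        float a[4] = {1, 2, 3, 4};
        float c[4];
        run(a, c, 4, Scale{0.5f}); // caller supplies the op instance with runtime state
        std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 0.5 1 1.5 2
        return 0;
    }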
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp  (View file @ f30e5975)
...
@@ -18,9 +18,11 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
+         typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
...
@@ -34,6 +36,7 @@ template <typename XDataType,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
+         index_t SaveMeanInvStdDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationNaiveVariance_mk_to_mk
{
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

+   static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                 "configuration, please check!");
+
    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
...
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

+   using ThreadBufferLengths_M                = Sequence<MThreadSliceSize>;
+   static constexpr auto thread_buffer_desc_m =
+       make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+
    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
...
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                                 reduce::Add,
                                                 true>;

+   using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
+                              const GridDesc_M& save_mean_grid_desc_m,
+                              const GridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_block_tile_iteration,
                               ComputeDataType epsilon,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
        // LDS
...
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

+       auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+
+       auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
+
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
...
@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            mean_square_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
            var_thread_buf = mean_square_thread_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
+           inv_std_thread_buf = mean_square_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
...
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                 thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

+       auto threadwise_mean_store =
+           ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                              SaveMeanInvStdDataType,
+                                              decltype(thread_buffer_desc_m),
+                                              GridDesc_M,
+                                              PassThroughOp,
+                                              ThreadBufferLengths_M,
+                                              Sequence<0>,                 // DimAccessOrder
+                                              0,                           // SrcVectorDim
+                                              SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+               save_mean_grid_desc_m,
+               make_multi_index(block_global_id * M_BlockTileSize +
+                                thread_m_cluster_id * MThreadSliceSize),
+               PassThroughOp{});
+
+       auto threadwise_inv_std_store =
+           ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                              SaveMeanInvStdDataType,
+                                              decltype(thread_buffer_desc_m),
+                                              GridDesc_M,
+                                              PassThroughOp,
+                                              ThreadBufferLengths_M,
+                                              Sequence<0>,                 // DimAccessOrder
+                                              0,                           // SrcVectorDim
+                                              SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+               save_inv_std_grid_desc_m,
+               make_multi_index(block_global_id * M_BlockTileSize +
+                                thread_m_cluster_id * MThreadSliceSize),
+               PassThroughOp{});
+
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
        // E(x), E[x^2], var(x)
        // FIXME: Should not hack the transform from deviceOP
-       int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
+       ComputeDataType reduce_length = type_convert<ComputeDataType>(
+           x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
...
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            // var(x) = E[x^2] - E[x]^2
            var_thread_buf(I) =
                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+           inv_std_thread_buf(I) =
+               type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

+       // save mean and inverse std for backward (optional)
+       if(thread_k_cluster_id == 0)
+       {
+           if(p_save_mean_global != nullptr)
+           {
+               threadwise_mean_store.Run(thread_buffer_desc_m,
+                                         make_tuple(I0),
+                                         mean_thread_buf,
+                                         save_mean_grid_desc_m,
+                                         save_mean_global_val_buf);
+           }
+
+           if(p_save_inv_std_global != nullptr)
+           {
+               threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            inv_std_thread_buf,
+                                            save_inv_std_grid_desc_m,
+                                            save_inv_std_global_val_buf);
+           }
+       }
+
+       // normalization
        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                       divisor;
+                       inv_std_thread_buf(iM);

                    // gamma & beta
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            // var(x) = E[x^2] - E[x]^2
            var_thread_buf(I) =
                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+           inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

+       if(thread_k_cluster_id == 0)
+       {
+           if(p_save_mean_global != nullptr)
+           {
+               threadwise_mean_store.Run(thread_buffer_desc_m,
+                                         make_tuple(I0),
+                                         mean_thread_buf,
+                                         save_mean_grid_desc_m,
+                                         save_mean_global_val_buf);
+           }
+
+           if(p_save_inv_std_global != nullptr)
+           {
+               threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            inv_std_thread_buf,
+                                            save_inv_std_grid_desc_m,
+                                            save_inv_std_global_val_buf);
+           }
+       }
+
        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                       divisor;
+                       inv_std_thread_buf(iM);

                    // gamma
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
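As a reference for what the naive-variance path above produces per row: the kernel accumulates E[x] and E[x^2], forms var(x) = E[x^2] - E[x]^2 and inv_std = 1 / sqrt(var + epsilon), normalizes with (x - mean) * inv_std before applying gamma and beta, and (new in this commit) can optionally write mean and inv_std out for a later backward pass. A small host-side check of those numbers on a made-up row (not CK code):

    // Reference host computation for one reduced row of the naive-variance kernel.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
        const float eps            = 1e-5f;

        float mean = 0.f, mean_sq = 0.f;
        for(float v : x)
        {
            mean += v;
            mean_sq += v * v;
        }
        mean /= x.size();
        mean_sq /= x.size();

        const float var     = mean_sq - mean * mean;      // E[x^2] - E[x]^2
        const float inv_std = 1.f / std::sqrt(var + eps); // value stored via p_save_inv_std_global

        for(float v : x)
            std::printf("%f ", (v - mean) * inv_std);      // normalized values before gamma/beta
        std::printf("\nmean=%f inv_std=%f\n", mean, inv_std); // mean=2.5, inv_std~0.894
        return 0;
    }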
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp  (View file @ f30e5975)
...
@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
-         typename GridDesc_M_K>
+         typename GridDesc_M_K,
+         typename GridDesc_M>
__global__ void
kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
                     const GridDesc_M_K gamma_grid_desc_m_k,
                     const GridDesc_M_K beta_grid_desc_m_k,
                     const GridDesc_M_K y_grid_desc_m_k,
+                    const GridDesc_M save_mean_grid_desc_m,
+                    const GridDesc_M save_inv_std_grid_desc_m,
                     index_t num_k_block_tile_iteration,
                     ComputeDataType epsilon,
                     const XDataType* const __restrict__ p_x_global,
                     const GammaDataType* const __restrict__ p_gamma_global,
                     const BetaDataType* const __restrict__ p_beta_global,
                     YDataType* const __restrict__ p_y_global,
+                    SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                    SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                     const YElementwiseOperation y_elementwise_op)
{
    GridwiseReduction::Run(x_grid_desc_m_k,
                           gamma_grid_desc_m_k,
                           beta_grid_desc_m_k,
                           y_grid_desc_m_k,
+                          save_mean_grid_desc_m,
+                          save_inv_std_grid_desc_m,
                           num_k_block_tile_iteration,
                           epsilon,
                           p_x_global,
                           p_gamma_global,
                           p_beta_global,
                           p_y_global,
+                          p_save_mean_global,
+                          p_save_inv_std_global,
                           y_elementwise_op);
};
...
@@ -44,9 +55,11 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
+         typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
...
@@ -60,6 +73,7 @@ template <typename XDataType,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
+         index_t SaveMeanInvStdDstVectorSize,
          bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce)
{
...
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                    GammaDataType,
                                                    BetaDataType,
                                                    YDataType,
+                                                   SaveMeanInvStdDataType,
                                                    ComputeDataType,
                                                    YElementwiseOperation,
                                                    GridDesc_M_K,
+                                                   GridDesc_M,
                                                    BlockSize,
                                                    MThreadClusterSize,
                                                    KThreadClusterSize,
...
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                    BetaSrcVectorSize,
                                                    YDstVectorDim,
                                                    YDstVectorSize,
+                                                   SaveMeanInvStdDstVectorSize,
                                                    false>;
    using GridwiseNormalizationSweepOnceNaive =
        GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
                                                    GammaDataType,
                                                    BetaDataType,
                                                    YDataType,
+                                                   SaveMeanInvStdDataType,
                                                    ComputeDataType,
                                                    YElementwiseOperation,
                                                    GridDesc_M_K,
+                                                   GridDesc_M,
                                                    BlockSize,
                                                    MThreadClusterSize,
                                                    KThreadClusterSize,
...
@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                    BetaSrcVectorSize,
                                                    YDstVectorDim,
                                                    YDstVectorSize,
+                                                   SaveMeanInvStdDstVectorSize,
                                                    true>;
    using GridwiseNormalizationGenericWelford =
        GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
+                                                     SaveMeanInvStdDataType,
                                                      ComputeDataType,
                                                      YElementwiseOperation,
                                                      GridDesc_M_K,
+                                                     GridDesc_M,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
...
@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                      BetaSrcVectorSize,
                                                      YDstVectorDim,
                                                      YDstVectorSize,
+                                                     SaveMeanInvStdDstVectorSize,
                                                      false>;
    using GridwiseNormalizationSweepOnceWelford =
        GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
+                                                     SaveMeanInvStdDataType,
                                                      ComputeDataType,
                                                      YElementwiseOperation,
                                                      GridDesc_M_K,
+                                                     GridDesc_M,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
...
@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                      BetaSrcVectorSize,
                                                      YDstVectorDim,
                                                      YDstVectorSize,
+                                                     SaveMeanInvStdDstVectorSize,
                                                      true>;

    if constexpr(UseWelford)
...
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                            GammaDataType,
                                            BetaDataType,
                                            YDataType,
+                                           SaveMeanInvStdDataType,
                                            ComputeDataType,
                                            YElementwiseOperation,
-                                           GridDesc_M_K>
+                                           GridDesc_M_K,
+                                           GridDesc_M>
                  : kernel_normalization<GridwiseNormalizationGenericWelford,
                                         XDataType,
                                         GammaDataType,
                                         BetaDataType,
                                         YDataType,
+                                        SaveMeanInvStdDataType,
                                         ComputeDataType,
                                         YElementwiseOperation,
-                                        GridDesc_M_K>;
+                                        GridDesc_M_K,
+                                        GridDesc_M>;
    }
    else
    {
...
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                            GammaDataType,
                                            BetaDataType,
                                            YDataType,
+                                           SaveMeanInvStdDataType,
                                            ComputeDataType,
                                            YElementwiseOperation,
-                                           GridDesc_M_K>
+                                           GridDesc_M_K,
+                                           GridDesc_M>
                  : kernel_normalization<GridwiseNormalizationGenericNaive,
                                         XDataType,
                                         GammaDataType,
                                         BetaDataType,
                                         YDataType,
+                                        SaveMeanInvStdDataType,
                                         ComputeDataType,
                                         YElementwiseOperation,
-                                        GridDesc_M_K>;
+                                        GridDesc_M_K,
+                                        GridDesc_M>;
    }
}
...