Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1b5af83d
Commit
1b5af83d
authored
Oct 20, 2023
by
illsilin
Browse files
Merge branch 'develop' into lwpck-976
parents
aac26d32
1fd27d52
Changes
176
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1873 additions
and
369 deletions
+1873
-369
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+86
-16
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+3
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+877
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+7
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+8
-1
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
...r_operation/gpu/device/impl/device_normalization_impl.hpp
+98
-23
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
...tion/gpu/device/impl/device_normalization_splitk_impl.hpp
+174
-84
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+103
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+27
-14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+18
-9
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
...d/normalization/gridwise_normalization_naive_variance.hpp
+112
-5
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
...pu/grid/normalization/gridwise_normalization_selector.hpp
+50
-16
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
.../grid/normalization/gridwise_normalization_splitk_2nd.hpp
+85
-4
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
...normalization/gridwise_normalization_welford_variance.hpp
+110
-7
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+4
-28
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-13
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+9
-26
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+98
-53
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
1b5af83d
...
@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
PipelineVer
,
PipelineVer
,
ComputeType
>
;
ComputeType
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
struct
Argument
:
public
GridwiseGemm
::
Argument
{
Argument
(
const
ADataType
*
p_a_grid_
,
const
BDataType
*
p_b_grid_
,
CDataType
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CElementwiseOperation
c_element_op_
)
:
GridwiseGemm
::
Argument
(
p_a_grid_
,
p_b_grid_
,
p_c_grid_
,
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
MPadded_
,
NPadded_
,
KPadded_
,
K0_
,
k_batch_
),
a_element_op
(
a_element_op_
),
b_element_op
(
b_element_op_
),
c_element_op
(
c_element_op_
)
{
}
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
CElementwiseOperation
c_element_op
;
};
using
DefaultBlock2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2CTileMap
;
using
DefaultBlock2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2CTileMap
;
// Invoker
// Invoker
...
@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
),
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
stream_config
.
stream_id_
));
ave_time
=
launch_and_time_kernel
(
ave_time
=
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
,
b2c_map
);
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
static_cast
<
typename
GridwiseGemm
::
Argument
>
(
karg
),
b2c_map
,
karg
.
a_element_op
,
karg
.
b_element_op
,
karg
.
c_element_op
);
};
};
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
...
@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
true
,
true
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
...
@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
true
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
...
@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
false
,
false
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
...
@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
false
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
...
@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
,
CElementwiseOperation
c_element_op
,
index_t
KBatch
)
index_t
KBatch
)
{
{
return
Argument
{
p_a
,
return
Argument
(
p_a
,
p_b
,
p_b
,
p_c
,
p_c
,
M
,
M
,
...
@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
KBatch
};
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
override
ck
::
index_t
KBatch
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
KBatch
);
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
1b5af83d
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
1b5af83d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...
@@ -22,32 +23,6 @@ namespace ck {
...
@@ -22,32 +23,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
// element-wise op
// element-wise op
OutElementwiseOperation
a_element_op_
;
OutElementwiseOperation
a_element_op_
;
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
,
has_main_loop
,
has_double_loop
>
;
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
0 → 100644
View file @
1b5af83d
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
K1
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
,
typename
ck
::
enable_if
<
NDimSpatial
==
3
,
bool
>
::
type
=
false
>
struct
DeviceGroupedConvBwdWeight_Wmma_CShuffle
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedConvBwdWeight_Wmma_CShuffle
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
CDataType
=
WeiDataType
;
using
AElementwiseOperation
=
OutElementwiseOperation
;
using
BElementwiseOperation
=
InElementwiseOperation
;
using
CElementwiseOperation
=
WeiElementwiseOperation
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
// 3d
static
constexpr
bool
is_NDHWGK_GKZYXC_NDHWGC
=
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NDHWGK
>
;
static
constexpr
bool
is_GNDHWK_GKZYXC_GNDHWC
=
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNDHWK
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
GemmK1Number
=
Number
<
K1
>
{};
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
GemmK1Number
;
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_out_grid_desc
(
const
index_t
N
,
const
index_t
Do
,
const
index_t
Ho
,
const
index_t
Wo
,
const
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_strides
)
{
const
index_t
WoStride
=
output_strides
[
5
];
const
auto
KStride
=
Number
<
1
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
),
make_tuple
(
WoStride
,
KStride
));
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_in_grid_desc
(
const
index_t
N
,
const
index_t
Di
,
const
index_t
Hi
,
const
index_t
Wi
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_strides
)
{
const
index_t
NStride
=
input_strides
[
1
];
const
index_t
DiStride
=
input_strides
[
3
];
const
index_t
HiStride
=
input_strides
[
4
];
const
index_t
WiStride
=
input_strides
[
5
];
const
auto
CStride
=
input_strides
[
2
];
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Di
*
Hi
*
Wi
,
C
),
make_tuple
(
WiStride
,
CStride
));
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
DiStride
,
HiStride
,
WiStride
,
CStride
));
}
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_wei_grid_desc
(
const
index_t
K
,
const
index_t
Z
,
const
index_t
Y
,
const
index_t
X
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weights_strides
)
{
const
auto
CStride
=
Number
<
1
>
{};
const
auto
KStride
=
weights_strides
[
1
];
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
),
make_tuple
(
KStride
,
CStride
));
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
const
index_t
N
,
const
index_t
K
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
using
namespace
ck
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
Z
=
filter_spatial_lengths
[
0
];
const
index_t
Y
=
filter_spatial_lengths
[
1
];
const
index_t
X
=
filter_spatial_lengths
[
2
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
GemmKTotal
=
N
*
Do
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
Z
*
X
*
Y
;
const
auto
PadGemmM
=
(
MPerBlock
-
GemmM
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadGemmN
=
(
NPerBlock
-
GemmN
%
NPerBlock
)
%
NPerBlock
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
const
auto
out_grid_desc
=
make_out_grid_desc
<
NDim
>
(
N
,
Do
,
Ho
,
Wo
,
K
,
output_strides
);
const
auto
in_grid_desc
=
make_in_grid_desc
<
NDim
>
(
N
,
Di
,
Hi
,
Wi
,
C
,
input_strides
);
const
auto
wei_grid_desc
=
make_wei_grid_desc
<
NDim
>
(
K
,
Z
,
Y
,
X
,
C
,
weights_strides
);
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_grid_desc
);
}
else
{
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// Pad
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc
=
transform_tensor_descriptor
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
PadGemmM
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc
=
transform_tensor_descriptor
(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmN
,
PadGemmN
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemmm_gemmn_pad_grid_desc
=
transform_tensor_descriptor
(
wei_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmM
,
PadGemmM
),
make_right_pad_transform
(
GemmN
,
PadGemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc
,
wei_gemmm_gemmn_pad_grid_desc
);
}
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
const
index_t
dim
=
1
;
const
std
::
array
<
index_t
,
NDimSpatial
>
lengths
{
1
,
1
,
1
};
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
strides
{
1
,
1
,
1
,
1
,
1
,
1
};
const
std
::
array
<
index_t
,
NDimSpatial
>
params
{
1
,
1
,
1
};
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
3
>
(
dim
,
dim
,
dim
,
lengths
,
lengths
,
lengths
,
strides
,
strides
,
strides
,
params
,
params
,
params
,
params
);
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
Tuple
<>
,
CDataType
,
// InMemory Data Descriptor
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
Tuple
<>
,
CGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
NumGemmKPrefetchStage
,
LoopSched
,
PipelineVer
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
Tuple
<>
{}));
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
I1
/* M01 */
,
I1
/* N01 */
));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
index_t
split_k
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_kbatch_k0_m_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
out_element_op
},
b_element_op_
{
in_element_op
},
c_element_op_
{
wei_element_op
},
Conv_G_
{
a_g_n_c_wis_lengths
[
0
]},
Conv_N_
{
a_g_n_c_wis_lengths
[
1
]},
Conv_K_
{
b_g_k_c_xs_lengths
[
1
]},
Conv_C_
{
a_g_n_c_wis_lengths
[
2
]},
input_spatial_lengths_
{},
filter_spatial_lengths_
{},
output_spatial_lengths_
{},
conv_filter_strides_
{
conv_filter_strides
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
{
constexpr
index_t
spatial_offset
=
3
;
std
::
copy
(
begin
(
a_g_n_c_wis_lengths
)
+
spatial_offset
,
end
(
a_g_n_c_wis_lengths
),
begin
(
input_spatial_lengths_
));
std
::
copy
(
begin
(
b_g_k_c_xs_lengths
)
+
spatial_offset
,
end
(
b_g_k_c_xs_lengths
),
begin
(
filter_spatial_lengths_
));
std
::
copy
(
begin
(
e_g_n_k_wos_lengths
)
+
spatial_offset
,
end
(
e_g_n_k_wos_lengths
),
begin
(
output_spatial_lengths_
));
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
a_grid_desc_kbatch_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
I1
/* M01 */
,
I1
/* N01 */
);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
e_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths_
),
end
(
filter_spatial_lengths_
),
index_t
{
1
},
std
::
multiplies
<>
{});
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_kbatch_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
OutElementwiseOperation
a_element_op_
;
InElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
const
index_t
Conv_G_
;
const
index_t
Conv_N_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads_
;
const
index_t
k_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
void
Print
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
Print
(
arg
);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
"setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseGemm
,
ADataType
,
BDataType
,
typename
GridwiseGemm
::
DsGridPointer
,
CDataType
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
>
;
using
EmptyTuple
=
Tuple
<>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
EmptyTuple
{},
// Ds
arg
.
p_c_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
Conv_G_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
{},
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
if
(
has_main_k0_block_loop
)
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
get_device_name
()
==
"gfx1100"
||
get_device_name
()
==
"gfx1101"
||
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
// TODO: Add support for split_k > 1
if
(
arg
.
k_batch_
!=
1
)
{
return
false
;
}
if
constexpr
(
!
(
is_NDHWGK_GKZYXC_NDHWGC
||
is_GNDHWK_GKZYXC_GNDHWC
))
{
return
false
;
}
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's a 1x1 convolution with stride=1 and no padding
for
(
int
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
{
return
false
;
}
}
}
// vector load A/B matrix from global memory
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
BBlockTransferSrcVectorDim
==
1
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
&&
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
index_t
split_k
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
a_g_n_c_wis_lengths
,
// input
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
// weight
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
// output
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
index_t
split_k
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
a_g_n_c_wis_lengths
,
// input
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
// weight
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
// output
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvBwdWeight_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
", "
<<
K1
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMRepeatPerShuffle
<<
", "
<<
CShuffleNRepeatPerShuffle
<<
", "
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
1b5af83d
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -21,32 +22,6 @@ namespace ck {
...
@@ -21,32 +22,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
...
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
>
;
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
NDimSpatial
==
1
)
if
constexpr
(
NDimSpatial
==
1
)
{
{
if
constexpr
(
!
is_GNWK_GKXC_GNWC
)
if
constexpr
(
!
is_GNWK_GKXC_GNWC
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
1b5af83d
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseOp
,
GridwiseOp
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
1b5af83d
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
return
ds_offset
;
return
ds_offset
;
}
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
// alias for kernels without multiple D
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
}
...
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
...
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
index_t
BatchStrideB_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
View file @
1b5af83d
...
@@ -28,6 +28,7 @@ template <typename XDataType,
...
@@ -28,6 +28,7 @@ template <typename XDataType,
typename
BetaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
NumReduceDim
,
...
@@ -43,12 +44,13 @@ template <typename XDataType,
...
@@ -43,12 +44,13 @@ template <typename XDataType,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
UseWelford
=
true
>
bool
UseWelford
=
true
>
struct
DeviceNormalizationImpl
:
public
DeviceNormalization
<
XDataType
,
struct
DeviceNormalizationImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
Rank
,
Rank
,
NumReduceDim
>
NumReduceDim
>
...
@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static_assert
(
!
reduceAllDim
);
// TODO
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
numBlockTileIteration
)
int
numBlockTileIteration
)
{
{
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
numSrcDim
=
Rank
;
static
constexpr
index_t
numSrcDim
=
Rank
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
...
@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return
(
in_grid_desc_m_k_padded
);
return
(
in_grid_desc_m_k_padded
);
};
};
static
auto
MakeSaveMeanInvStdDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
const
auto
tupleSrcLengths
=
make_tuple_from_array_and_index_seq
(
lengths
,
InvariantDims
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array_and_index_seq
(
strides
,
InvariantDims
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
grid_desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
InvariantDims
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
pad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
pad_M
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
grid_desc_m_padded
;
}
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
));
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
));
using
GridDesc_M
=
decltype
(
MakeSaveMeanInvStdDescriptor_M
({
1
},
{
1
}));
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
std
::
vector
<
index_t
>
reduceDims
,
YElementwiseOperation
y_elementwise_op
,
YElementwiseOperation
y_elementwise_op
,
double
epsilon
,
double
epsilon
,
const
XDataType
*
p_x
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
)
YDataType
*
p_y
,
SaveMeanInvStdDataType
*
p_saveMean
,
SaveMeanInvStdDataType
*
p_saveInvStd
)
:
p_x_
(
p_x
),
:
p_x_
(
p_x
),
p_gamma_
(
p_gamma
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_beta_
(
p_beta
),
p_y_
(
p_y
),
p_y_
(
p_y
),
p_saveMean_
(
p_saveMean
),
p_saveInvStd_
(
p_saveInvStd
),
y_elementwise_op_
(
y_elementwise_op
)
y_elementwise_op_
(
y_elementwise_op
)
{
{
epsilon_
=
static_cast
<
ComputeDataType
>
(
epsilon
);
epsilon_
=
static_cast
<
ComputeDataType
>
(
epsilon
);
...
@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
saveMeanStrides_
=
saveMeanStrides
;
saveInvStdStrides_
=
saveInvStdStrides
;
long_index_t
invariant_length
;
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
long_index_t
reduce_length
;
std
::
tie
(
invariant_length
,
reduce_length
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
numBlockTileIteration_
=
math
::
integer_divide_ceil
(
reduce_length
,
K_BlockTileSize
);
numBlockTileIteration_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
);
gridSize_
=
math
::
integer_divide_ceil
(
invariant_length
,
M_BlockTileSize
);
gridSize_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
numBlockTileIteration_
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
numBlockTileIteration_
);
gamma_grid_desc_m_k_
=
gamma_grid_desc_m_k_
=
...
@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
beta_grid_desc_m_k_
=
beta_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
numBlockTileIteration_
);
save_mean_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveMeanStrides
);
save_inv_std_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveInvStdStrides
);
isSweeponce_
=
isSweeponce_
=
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
if
constexpr
(
NumInvariantDim
==
0
)
invariant_lowest_length_
=
1
;
else
invariant_lowest_length_
=
Lengths_
[
NumInvariantDim
-
1
];
}
}
ComputeDataType
epsilon_
;
ComputeDataType
epsilon_
;
...
@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
GammaDataType
*
p_gamma_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
const
BetaDataType
*
p_beta_
;
YDataType
*
p_y_
;
YDataType
*
p_y_
;
SaveMeanInvStdDataType
*
p_saveMean_
;
SaveMeanInvStdDataType
*
p_saveInvStd_
;
std
::
vector
<
index_t
>
Lengths_
;
std
::
vector
<
index_t
>
Lengths_
;
std
::
vector
<
index_t
>
xStrides_
;
std
::
vector
<
index_t
>
xStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
saveMeanStrides_
;
std
::
vector
<
index_t
>
saveInvStdStrides_
;
YElementwiseOperation
y_elementwise_op_
;
YElementwiseOperation
y_elementwise_op_
;
...
@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GridDesc_M_K
gamma_grid_desc_m_k_
;
GridDesc_M_K
gamma_grid_desc_m_k_
;
GridDesc_M_K
beta_grid_desc_m_k_
;
GridDesc_M_K
beta_grid_desc_m_k_
;
GridDesc_M_K
y_grid_desc_m_k_
;
GridDesc_M_K
y_grid_desc_m_k_
;
GridDesc_M
save_mean_grid_desc_m_
;
GridDesc_M
save_inv_std_grid_desc_m_
;
bool
isSweeponce_
;
bool
isSweeponce_
;
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
...
@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
BlockSize
,
MThreadClusterSize
,
MThreadClusterSize
,
KThreadClusterSize
,
KThreadClusterSize
,
...
@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
BetaSrcVectorSize
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
XYSrcVectorDim
,
YDstVectorSize
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
UseWelford
>
(
arg
.
isSweeponce_
);
UseWelford
>
(
arg
.
isSweeponce_
);
float
avg_time
=
0
;
float
avg_time
=
0
;
...
@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg
.
gamma_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
save_mean_grid_desc_m_
,
arg
.
save_inv_std_grid_desc_m_
,
arg
.
numBlockTileIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
epsilon_
,
arg
.
p_x_
,
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
p_y_
,
arg
.
p_saveMean_
,
arg
.
p_saveInvStd_
,
arg
.
y_elementwise_op_
);
arg
.
y_elementwise_op_
);
return
(
avg_time
);
return
(
avg_time
);
...
@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XYSrcVectorDim
==
0
)
if
constexpr
(
XYSrcVectorDim
==
0
)
{
{
if
constexpr
(
NumInvariantDim
==
0
)
if
constexpr
(
NumInvariantDim
==
0
)
...
@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
}
}
else
else
{
{
printf
(
"!!!! %d
\n
"
,
p_arg_
->
invariant_lowest_length_
);
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
XSrcVectorSize
!=
0
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
YDstVectorSize
!=
0
)
return
false
;
return
false
;
};
};
}
}
...
@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
);
return
(
false
);
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
BetaSrcVectorSize
!=
0
)
return
(
false
);
return
(
false
);
}
}
else
// if fastest dim is reduced
else
// if fastest dim is reduced
...
@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return
(
false
);
return
(
false
);
}
}
if
(
p_arg_
->
invariant_lowest_length_
%
SaveMeanInvStdDstVectorSize
!=
0
)
return
false
;
return
true
;
return
true
;
};
};
...
@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
std
::
vector
<
index_t
>
reduceDims
,
double
epsilon
,
double
epsilon
,
const
void
*
p_x
,
const
void
*
p_x
,
...
@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
void
*
p_beta
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveMean
,
void
*
p_saveInv
Var
,
void
*
p_saveInv
Std
,
YElementwiseOperation
y_elementwise_op
)
override
YElementwiseOperation
y_elementwise_op
)
override
{
{
// TODO
if
(
lengths
.
size
()
!=
Rank
||
xStrides
.
size
()
!=
Rank
||
gammaStrides
.
size
()
!=
Rank
||
// Optional cache of the intermediate results (mean and InvVariance) during the
betaStrides
.
size
()
!=
Rank
||
yStrides
.
size
()
!=
Rank
||
// forward pass could speedup in the backward
saveMeanStrides
.
size
()
!=
NumInvariantDim
||
saveInvStdStrides
.
size
()
!=
NumInvariantDim
)
ignore
=
p_saveMean
;
throw
std
::
runtime_error
(
"dimension is incorrect"
);
ignore
=
p_saveInvVar
;
return
std
::
make_unique
<
Argument
>
(
lengths
,
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
xStrides
,
gammaStrides
,
gammaStrides
,
betaStrides
,
betaStrides
,
yStrides
,
yStrides
,
saveMeanStrides
,
saveInvStdStrides
,
reduceDims
,
reduceDims
,
y_elementwise_op
,
y_elementwise_op
,
epsilon
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
YDataType
*>
(
p_y
));
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveMean
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveInvStd
));
};
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
View file @
1b5af83d
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseWelford
,
template
<
typename
GridwiseWelford
,
typename
XDataType
,
typename
XDataType
,
typename
MeanVarDataType
,
typename
Workspace
MeanVarDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
XGridDesc_M_K
,
typename
XGridDesc_M_K
,
typename
MeanVarGridDesc_M_KBlock
>
typename
MeanVarGridDesc_M_KBlock
>
...
@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
...
@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
Workspace
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
Workspace
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
)
int32_t
*
const
__restrict__
p_welford_count
)
{
{
GridwiseWelford
::
Run
(
x_grid_desc_m_k
,
GridwiseWelford
::
Run
(
x_grid_desc_m_k
,
...
@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
...
@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
};
};
template
<
typename
GridwiseWelfordNormalization
,
template
<
typename
GridwiseWelfordNormalization
,
typename
MeanVarDataType
,
typename
Workspace
MeanVarDataType
,
typename
XDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
>
typename
XYGammaBetaGridDesc_M_K
,
typename
SaveMeanInvStdGridDesc_M
>
__global__
void
__global__
void
kernel_normalizationSplitK2nd
(
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
kernel_normalizationSplitK2nd
(
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
const
CountGridDesc_M_KBlock
count_grid_desc_m_kblock
,
const
CountGridDesc_M_KBlock
count_grid_desc_m_kblock
,
...
@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
...
@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
const
XYGammaBetaGridDesc_M_K
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
y_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
y_grid_desc_m_k
,
const
SaveMeanInvStdGridDesc_M
save_mean_grid_desc_m
,
const
SaveMeanInvStdGridDesc_M
save_inv_std_grid_desc_m
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
index_t
k_grid_size
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
const
MeanVarDataType
*
const
p_mean_global
,
const
Workspace
MeanVarDataType
*
const
p_mean_global
,
const
MeanVarDataType
*
const
p_variance_global
,
const
Workspace
MeanVarDataType
*
const
p_variance_global
,
const
int32_t
*
const
p_welford_count_global
,
const
int32_t
*
const
p_welford_count_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
const
YElementwiseOperation
y_elementwise_op
)
{
{
GridwiseWelfordNormalization
::
Run
(
mean_var_grid_desc_m_kblock
,
GridwiseWelfordNormalization
::
Run
(
mean_var_grid_desc_m_kblock
,
...
@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
...
@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
gamma_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
y_grid_desc_m_k
,
save_mean_grid_desc_m
,
save_inv_std_grid_desc_m
,
num_k_mean_var_count_iteration
,
num_k_mean_var_count_iteration
,
num_k_block_tile_iteration
,
num_k_block_tile_iteration
,
k_grid_size
,
k_grid_size
,
...
@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
...
@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
p_gamma_global
,
p_gamma_global
,
p_beta_global
,
p_beta_global
,
p_y_global
,
p_y_global
,
p_save_mean_global
,
p_save_inv_std_global
,
y_elementwise_op
);
y_elementwise_op
);
};
};
}
// namespace ck
}
// namespace ck
...
@@ -107,6 +117,7 @@ template <typename XDataType,
...
@@ -107,6 +117,7 @@ template <typename XDataType,
typename
BetaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
NumReduceDim
,
...
@@ -121,17 +132,18 @@ template <typename XDataType,
...
@@ -121,17 +132,18 @@ template <typename XDataType,
index_t
GammaSrcVectorSize
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
>
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
>
struct
DeviceNormalizationSplitKImpl
:
public
DeviceNormalization
<
XDataType
,
struct
DeviceNormalizationSplitKImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
Rank
,
Rank
,
NumReduceDim
>
NumReduceDim
>
{
{
using
MeanVarDataType
=
Compute
DataType
;
using
Workspace
MeanVarDataType
=
SaveMeanInvStd
DataType
;
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(
static_assert
(
...
@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static_assert
(
!
reduceAllDim
);
// TODO
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
kBlockSize
,
int
kBlockSize
,
int
numBlockTileIteration
)
int
numBlockTileIteration
)
{
{
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
numSrcDim
=
Rank
;
static
constexpr
index_t
numSrcDim
=
Rank
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
...
@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
};
};
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeMeanVarDescriptor_M_K
(
index_t
M
,
index_t
K
)
static
auto
Make
Workspace
MeanVarDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
{
const
auto
grid_desc_m_k
=
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
...
@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}
}
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeCountDescriptor_M_K
(
index_t
M
,
index_t
K
)
static
auto
Make
Workspace
CountDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
{
const
auto
grid_desc_m_k
=
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I0
,
I1
));
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I0
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
}
}
static
auto
MakeSaveMeanInvStdDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
const
auto
tupleSrcLengths
=
make_tuple_from_array_and_index_seq
(
lengths
,
InvariantDims
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array_and_index_seq
(
strides
,
InvariantDims
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
grid_desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
InvariantDims
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
pad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
pad_M
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
grid_desc_m_padded
;
}
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
Kernel1MeanVarGridDesc_M_KBlock
=
using
Kernel1MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
decltype
(
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2MeanVarGridDesc_M_KBlock
=
using
Kernel2MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
decltype
(
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2CountGridDesc_M_KBlock
=
using
Kernel2CountGridDesc_M_KBlock
=
decltype
(
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
decltype
(
MakeWorkspaceCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
SaveMeanInvStdGridDesc_M
=
decltype
(
MakeSaveMeanInvStdDescriptor_M
({
1
},
{
1
}));
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
ComputeDataType
,
ComputeDataType
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
SrcGridDesc_M_K
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
,
Kernel1MeanVarGridDesc_M_KBlock
,
BlockSize
,
BlockSize
,
...
@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
XSrcVectorSize
>
;
XSrcVectorSize
>
;
using
GridwiseWelfordNormalization
=
using
GridwiseWelfordNormalization
=
GridwiseNormalizationSplitK2nd
<
MeanVarDataType
,
GridwiseNormalizationSplitK2nd
<
Workspace
MeanVarDataType
,
XDataType
,
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
,
SrcGridDesc_M_K
,
SaveMeanInvStdGridDesc_M
,
BlockSize
,
BlockSize
,
MThreadClusterSize
,
MThreadClusterSize
,
KThreadClusterSize
,
KThreadClusterSize
,
...
@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
BetaSrcVectorDim
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
BetaSrcVectorSize
,
XYVectorDim
,
XYVectorDim
,
YDstVectorSize
>
;
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
std
::
vector
<
index_t
>
reduceDims
,
YElementwiseOperation
y_elementwise_op
,
YElementwiseOperation
y_elementwise_op
,
double
epsilon
,
double
epsilon
,
const
XDataType
*
p_x
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
)
YDataType
*
p_y
,
SaveMeanInvStdDataType
*
p_saveMean
,
SaveMeanInvStdDataType
*
p_saveInvStd
)
:
p_x_
(
p_x
),
:
p_x_
(
p_x
),
p_gamma_
(
p_gamma
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_beta_
(
p_beta
),
p_y_
(
p_y
),
p_y_
(
p_y
),
p_saveMean_
(
p_saveMean
),
p_saveInvStd_
(
p_saveInvStd
),
p_workspace_mean_
{
nullptr
},
p_workspace_mean_
{
nullptr
},
p_workspace_var_
{
nullptr
},
p_workspace_var_
{
nullptr
},
p_workspace_count_
{
nullptr
},
p_workspace_count_
{
nullptr
},
...
@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
saveMeanStrides_
=
saveMeanStrides
;
saveInvStdStrides_
=
saveInvStdStrides
;
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
...
@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
y_grid_desc_m_k_
=
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
save_mean_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveMeanStrides
);
save_inv_std_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveInvStdStrides
);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_
=
kernel1_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
kGridSize_
);
MRaw_
,
kGridSize_
);
kernel2_mean_var_grid_desc_m_kblock_
=
kernel2_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
kernel2_count_grid_desc_m_kblock_
=
kernel2_count_grid_desc_m_kblock_
=
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
MakeWorkspaceCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
if
constexpr
(
NumInvariantDim
==
0
)
invariant_lowest_length_
=
1
;
else
invariant_lowest_length_
=
Lengths_
[
NumInvariantDim
-
1
];
}
}
ComputeDataType
epsilon_
;
ComputeDataType
epsilon_
;
...
@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
GammaDataType
*
p_gamma_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
const
BetaDataType
*
p_beta_
;
YDataType
*
p_y_
;
YDataType
*
p_y_
;
SaveMeanInvStdDataType
*
p_saveMean_
;
SaveMeanInvStdDataType
*
p_saveInvStd_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_var_
;
void
*
p_workspace_var_
;
void
*
p_workspace_count_
;
void
*
p_workspace_count_
;
...
@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
saveMeanStrides_
;
std
::
vector
<
index_t
>
saveInvStdStrides_
;
YElementwiseOperation
y_elementwise_op_
;
YElementwiseOperation
y_elementwise_op_
;
...
@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K
gamma_grid_desc_m_k_
;
SrcGridDesc_M_K
gamma_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
SaveMeanInvStdGridDesc_M
save_mean_grid_desc_m_
;
SaveMeanInvStdGridDesc_M
save_inv_std_grid_desc_m_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
...
@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
index_t
MRaw_
;
// invarient length
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
...
@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
XDataType
,
XDataType
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
ComputeDataType
,
ComputeDataType
,
SrcGridDesc_M_K
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
>
;
Kernel1MeanVarGridDesc_M_KBlock
>
;
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
XDataType
,
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
>
;
SrcGridDesc_M_K
,
SaveMeanInvStdGridDesc_M
>
;
float
avg_time
=
0
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
kernel1
,
stream_config
,
dim3
(
arg
.
gridSize_
),
kernel1
,
dim3
(
BlockSize
),
dim3
(
arg
.
gridSize_
),
0
,
dim3
(
BlockSize
),
arg
.
x_grid_desc_m_k_
,
0
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
numBlockTileIteration_
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
p_x_
,
arg
.
numBlockTileIteration_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
static_cast
<
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
avg_time
+=
launch_and_time_kernel
(
dim3
(
arg
.
gridSize_
),
stream_config
,
dim3
(
BlockSize
),
kernel2
,
0
,
dim3
(
arg
.
gridSize_
),
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
dim3
(
BlockSize
),
arg
.
kernel2_count_grid_desc_m_kblock_
,
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
numMeanVarCountIteration_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
numBlockTileIteration_
,
arg
.
y_grid_desc_m_k_
,
arg
.
kGridSize_
,
arg
.
save_mean_grid_desc_m_
,
arg
.
epsilon_
,
arg
.
save_inv_std_grid_desc_m_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
arg
.
numMeanVarCountIteration_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
arg
.
numBlockTileIteration_
,
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
kGridSize_
,
arg
.
p_x_
,
arg
.
epsilon_
,
arg
.
p_gamma_
,
static_cast
<
const
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
arg
.
p_beta_
,
static_cast
<
const
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_var_
),
arg
.
p_y_
,
static_cast
<
const
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
y_elementwise_op_
);
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
p_saveMean_
,
arg
.
p_saveInvStd_
,
arg
.
y_elementwise_op_
);
return
avg_time
;
return
avg_time
;
};
};
...
@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
// workspace for welford intermediate mean
// workspace for welford intermediate mean
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
workspace_size
+=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
// workspace for welford intermediate variance
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
workspace_size
+=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
kGridSize_
*
sizeof
(
int32_t
)
+
64
;
workspace_size
+=
pArg_
->
kGridSize_
*
sizeof
(
int32_t
)
+
64
;
...
@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
// setup buffer used for intermediate welford mean
// setup buffer used for intermediate welford mean
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
// setup buffer used for intermediate welford varirance
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
index_t
variance_space_sz
=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welford count
// setup buffer used for intermediate welford count
...
@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
{
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XYVectorDim
==
0
)
if
constexpr
(
XYVectorDim
==
0
)
{
{
if
constexpr
(
NumInvariantDim
==
0
)
if
constexpr
(
NumInvariantDim
==
0
)
...
@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
XSrcVectorSize
!=
0
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
YDstVectorSize
!=
0
)
return
false
;
return
false
;
};
};
}
}
...
@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
BetaSrcVectorSize
!=
0
)
return
false
;
return
false
;
}
}
else
// if fastest dim is reduced
else
// if fastest dim is reduced
...
@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
kGridSize_
<=
1
)
if
(
p_arg_
->
kGridSize_
<=
1
)
return
false
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length_
%
SaveMeanInvStdDstVectorSize
!=
0
)
return
false
;
return
true
;
return
true
;
};
};
...
@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
std
::
vector
<
index_t
>
reduceDims
,
double
epsilon
,
double
epsilon
,
const
void
*
p_x
,
const
void
*
p_x
,
...
@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
void
*
p_beta
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveMean
,
void
*
p_saveInv
Var
,
void
*
p_saveInv
Std
,
YElementwiseOperation
y_elementwise_op
)
override
YElementwiseOperation
y_elementwise_op
)
override
{
{
// TODO
if
(
lengths
.
size
()
!=
Rank
||
xStrides
.
size
()
!=
Rank
||
gammaStrides
.
size
()
!=
Rank
||
// Optional cache of the intermediate results (mean and InvVariance) during the
betaStrides
.
size
()
!=
Rank
||
yStrides
.
size
()
!=
Rank
||
// forward pass could speedup in the backward
saveMeanStrides
.
size
()
!=
NumInvariantDim
||
saveInvStdStrides
.
size
()
!=
NumInvariantDim
)
ignore
=
p_saveMean
;
throw
std
::
runtime_error
(
"dimension is incorrect"
);
ignore
=
p_saveInvVar
;
return
std
::
make_unique
<
Argument
>
(
lengths
,
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
xStrides
,
gammaStrides
,
gammaStrides
,
betaStrides
,
betaStrides
,
yStrides
,
yStrides
,
saveMeanStrides
,
saveInvStdStrides
,
reduceDims
,
reduceDims
,
y_elementwise_op
,
y_elementwise_op
,
epsilon
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
YDataType
*>
(
p_y
));
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveMean
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveInvStd
));
};
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
1b5af83d
...
@@ -113,7 +113,6 @@ struct PassThrough
...
@@ -113,7 +113,6 @@ struct PassThrough
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
{
...
@@ -143,9 +142,7 @@ struct PassThrough
...
@@ -143,9 +142,7 @@ struct PassThrough
{
{
y
=
type_convert
<
f8_t
>
(
x
);
y
=
type_convert
<
f8_t
>
(
x
);
}
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
{
{
...
@@ -175,7 +172,6 @@ struct PassThrough
...
@@ -175,7 +172,6 @@ struct PassThrough
{
{
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
}
}
#endif
};
};
struct
UnaryConvert
struct
UnaryConvert
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
}
}
};
};
#if defined CK_ENABLE_FP8
struct
ConvertF8SR
struct
ConvertF8SR
{
{
// convert to fp8 using stochastic rounding (SR)
// convert to fp8 using stochastic rounding (SR)
...
@@ -212,7 +207,8 @@ struct ConvertF8SR
...
@@ -212,7 +207,8 @@ struct ConvertF8SR
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
{
// check Y datatype
// check Y datatype
static_assert
(
is_same
<
Y
,
f8_t
>::
value
,
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
f8_t
>::
value
||
is_same
<
Y
,
bf8_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
...
@@ -221,7 +217,6 @@ struct ConvertF8SR
...
@@ -221,7 +217,6 @@ struct ConvertF8SR
y
=
f8_convert_sr
<
Y
>
(
x
);
y
=
f8_convert_sr
<
Y
>
(
x
);
}
}
};
};
#endif
struct
Scale
struct
Scale
{
{
...
@@ -448,10 +443,11 @@ struct Sigmoid
...
@@ -448,10 +443,11 @@ struct Sigmoid
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
1
/
(
ck
::
type_convert
<
T
>
(
1
)
+
exp
(
-
x
));
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
};
};
};
};
...
@@ -461,7 +457,8 @@ struct TanH
...
@@ -461,7 +457,8 @@ struct TanH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tanh
(
x
);
y
=
ck
::
math
::
tanh
(
x
);
...
@@ -487,7 +484,101 @@ struct Swish
...
@@ -487,7 +484,101 @@ struct Swish
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
};
};
float
beta_
=
1.0
f
;
const
float
beta_
;
};
struct
SoftRelu
{
SoftRelu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
const
float
alpha_
;
};
struct
Power
{
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
const
float
alpha_
;
const
float
beta_
;
const
float
gamma_
;
};
struct
ClippedRelu
{
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
const
float
alpha_
;
const
float
beta_
;
};
struct
LeakyRelu
{
LeakyRelu
(
float
alpha
=
0.01
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
const
float
alpha_
;
};
struct
Elu
{
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
const
float
alpha_
;
};
};
}
// namespace element_wise
}
// namespace element_wise
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
1b5af83d
...
@@ -36,7 +36,7 @@ __global__ void
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
(
kernel_grouped_conv_multiple_d_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
DsPointer
p_ds_grid
,
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template
<
typename
Block2CTileMap
>
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return
true
;
return
true
;
}
}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
return
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
e_grid_desc_m_n
,
block_2_ctile_map
);
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
1b5af83d
...
@@ -22,13 +22,19 @@ namespace ck {
...
@@ -22,13 +22,19 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
Block2CTileMap
>
typename
Block2CTileMap
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdlops_v2r4r2_simplified
(
typename
GridwiseGemm
::
Argument
karg
,
kernel_gemm_xdlops_v2r4r2_simplified
(
typename
GridwiseGemm
::
Argument
karg
,
const
Block2CTileMap
&
b2c_map
)
const
Block2CTileMap
&
b2c_map
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -37,10 +43,13 @@ __global__ void
...
@@ -37,10 +43,13 @@ __global__ void
__shared__
uint8_t
p_shared
[
shared_size
];
__shared__
uint8_t
p_shared
[
shared_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
);
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
,
a_element_op
,
b_element_op
,
c_element_op
);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
ignore
=
b2c_map
;
ignore
=
b2c_map
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
typename
Block2CTileMap
>
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
Argument
&
karg
,
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared_block
,
void
*
__restrict__
p_shared_block
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
{
const
FloatA
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatA
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
...
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{};
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{};
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{};
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
...
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
ComputeType
,
ComputeType
,
// ComputeType A
ComputeType
,
ComputeType
,
// ComputeType B
FloatAcc
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
View file @
1b5af83d
...
@@ -18,9 +18,11 @@ template <typename XDataType,
...
@@ -18,9 +18,11 @@ template <typename XDataType,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
KThreadClusterSize
,
...
@@ -34,6 +36,7 @@ template <typename XDataType,
...
@@ -34,6 +36,7 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
SweepOnce
>
bool
SweepOnce
>
struct
GridwiseNormalizationNaiveVariance_mk_to_mk
struct
GridwiseNormalizationNaiveVariance_mk_to_mk
{
{
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
...
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
using
ThreadReduceDstDesc_M
=
using
ThreadReduceDstDesc_M
=
...
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
reduce
::
Add
,
reduce
::
Add
,
true
>
;
true
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M
&
save_mean_grid_desc_m
,
const
GridDesc_M
&
save_inv_std_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
const
YElementwiseOperation
y_elementwise_op
)
{
{
// LDS
// LDS
...
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
save_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_mean_global
,
save_mean_grid_desc_m
.
GetElementSpaceSize
());
auto
save_inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_inv_std_global
,
save_inv_std_grid_desc_m
.
GetElementSpaceSize
());
auto
x_thread_buf
=
generate_tuple
(
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
...
@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
mean_square_thread_buf
;
mean_square_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>&
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>&
var_thread_buf
=
mean_square_thread_buf
;
var_thread_buf
=
mean_square_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>&
inv_std_thread_buf
=
mean_square_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
...
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
thread_k_cluster_id
*
YDstVectorSize
),
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
y_elementwise_op
);
auto
threadwise_mean_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_mean_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_inv_std_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_inv_std_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
...
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// E(x), E[x^2], var(x)
// E(x), E[x^2], var(x)
// FIXME: Should not hack the transform from deviceOP
// FIXME: Should not hack the transform from deviceOP
int
reduce_length
=
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
];
ComputeDataType
reduce_length
=
type_convert
<
ComputeDataType
>
(
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
]);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
ComputeDataType
>();
mean_thread_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
ComputeDataType
>();
...
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2
// var(x) = E[x^2] - E[x]^2
var_thread_buf
(
I
)
=
var_thread_buf
(
I
)
=
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
inv_std_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
});
// save mean and inverse std for backward (optional)
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
// normalization
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
constexpr
auto
offset_m_k
=
...
@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma & beta
// gamma & beta
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2
// var(x) = E[x^2] - E[x]^2
var_thread_buf
(
I
)
=
var_thread_buf
(
I
)
=
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
inv_std_thread_buf
(
I
)
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
});
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
auto
thread_copy_tail_m_k
=
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
...
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
});
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
constexpr
auto
offset_m_k
=
...
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
...
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
View file @
1b5af83d
...
@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
...
@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
>
typename
GridDesc_M_K
,
__global__
void
kernel_normalization
(
const
GridDesc_M_K
x_grid_desc_m_k
,
typename
GridDesc_M
>
const
GridDesc_M_K
gamma_grid_desc_m_k
,
__global__
void
const
GridDesc_M_K
beta_grid_desc_m_k
,
kernel_normalization
(
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
ComputeDataType
epsilon
,
const
GridDesc_M_K
y_grid_desc_m_k
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GridDesc_M
save_mean_grid_desc_m
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
GridDesc_M
save_inv_std_grid_desc_m
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
index_t
num_k_block_tile_iteration
,
YDataType
*
const
__restrict__
p_y_global
,
ComputeDataType
epsilon
,
const
YElementwiseOperation
y_elementwise_op
)
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
{
GridwiseReduction
::
Run
(
x_grid_desc_m_k
,
GridwiseReduction
::
Run
(
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
y_grid_desc_m_k
,
save_mean_grid_desc_m
,
save_inv_std_grid_desc_m
,
num_k_block_tile_iteration
,
num_k_block_tile_iteration
,
epsilon
,
epsilon
,
p_x_global
,
p_x_global
,
p_gamma_global
,
p_gamma_global
,
p_beta_global
,
p_beta_global
,
p_y_global
,
p_y_global
,
p_save_mean_global
,
p_save_inv_std_global
,
y_elementwise_op
);
y_elementwise_op
);
};
};
...
@@ -44,9 +55,11 @@ template <typename XDataType,
...
@@ -44,9 +55,11 @@ template <typename XDataType,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
KThreadClusterSize
,
...
@@ -60,6 +73,7 @@ template <typename XDataType,
...
@@ -60,6 +73,7 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
UseWelford
>
bool
UseWelford
>
auto
NormalizationKernelSelector
(
bool
isSweepOnce
)
auto
NormalizationKernelSelector
(
bool
isSweepOnce
)
{
{
...
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
BlockSize
,
MThreadClusterSize
,
MThreadClusterSize
,
KThreadClusterSize
,
KThreadClusterSize
,
...
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorDim
,
YDstVectorSize
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
false
>
;
false
>
;
using
GridwiseNormalizationSweepOnceNaive
=
using
GridwiseNormalizationSweepOnceNaive
=
GridwiseNormalizationNaiveVariance_mk_to_mk
<
XDataType
,
GridwiseNormalizationNaiveVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
BlockSize
,
MThreadClusterSize
,
MThreadClusterSize
,
KThreadClusterSize
,
KThreadClusterSize
,
...
@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorDim
,
YDstVectorSize
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
true
>
;
true
>
;
using
GridwiseNormalizationGenericWelford
=
using
GridwiseNormalizationGenericWelford
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
BlockSize
,
MThreadClusterSize
,
MThreadClusterSize
,
KThreadClusterSize
,
KThreadClusterSize
,
...
@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorDim
,
YDstVectorSize
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
false
>
;
false
>
;
using
GridwiseNormalizationSweepOnceWelford
=
using
GridwiseNormalizationSweepOnceWelford
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
BlockSize
,
MThreadClusterSize
,
MThreadClusterSize
,
KThreadClusterSize
,
KThreadClusterSize
,
...
@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorDim
,
YDstVectorSize
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
true
>
;
true
>
;
if
constexpr
(
UseWelford
)
if
constexpr
(
UseWelford
)
...
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
>
GridDesc_M_K
,
GridDesc_M
>
:
kernel_normalization
<
GridwiseNormalizationGenericWelford
,
:
kernel_normalization
<
GridwiseNormalizationGenericWelford
,
XDataType
,
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
>
;
GridDesc_M_K
,
GridDesc_M
>
;
}
}
else
else
{
{
...
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
...
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
>
GridDesc_M_K
,
GridDesc_M
>
:
kernel_normalization
<
GridwiseNormalizationGenericNaive
,
:
kernel_normalization
<
GridwiseNormalizationGenericNaive
,
XDataType
,
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
ComputeDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
>
;
GridDesc_M_K
,
GridDesc_M
>
;
}
}
}
}
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
View file @
1b5af83d
...
@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
...
@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
,
typename
XYGammaBetaGridDesc_M_K
,
typename
SaveMeanInvStdGridDesc_M
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
KThreadClusterSize
,
...
@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
...
@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
>
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
>
struct
GridwiseNormalizationSplitK2nd
struct
GridwiseNormalizationSplitK2nd
{
{
static_assert
((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
static_assert
((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
...
@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
static
constexpr
auto
thread_buffer_desc_m_1
=
static
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
I1
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
I1
));
...
@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
const
XYGammaBetaGridDesc_M_K
&
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
y_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
y_grid_desc_m_k
,
const
SaveMeanInvStdGridDesc_M
&
save_mean_grid_desc_m
,
const
SaveMeanInvStdGridDesc_M
&
save_inv_std_grid_desc_m
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
index_t
k_grid_size
,
...
@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
const
YElementwiseOperation
y_elementwise_op
)
{
{
// Thread/Block id
// Thread/Block id
...
@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
save_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_mean_global
,
save_mean_grid_desc_m
.
GetElementSpaceSize
());
auto
save_inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_inv_std_global
,
save_inv_std_grid_desc_m
.
GetElementSpaceSize
());
// VGPR
// VGPR
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
in_mean_thread_buf
;
in_mean_thread_buf
;
...
@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
var_thread_buf
;
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
welford_count_thread_buf
;
auto
&
inv_std_thread_buf
=
var_thread_buf
;
auto
x_thread_buf
=
generate_tuple
(
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
[
&
](
auto
)
{
...
@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
thread_k_cluster_id
*
YDstVectorSize
),
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
y_elementwise_op
);
auto
threadwise_mean_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
SaveMeanInvStdGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_mean_grid_desc_m
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_inv_std_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
SaveMeanInvStdGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_inv_std_grid_desc_m
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
// step1: Merge mean and variance
// step1: Merge mean and variance
constexpr
auto
mean_var_count_thread_copy_step_I0_k
=
constexpr
auto
mean_var_count_thread_copy_step_I0_k
=
make_multi_index
(
I0
,
KThreadClusterSize
);
make_multi_index
(
I0
,
KThreadClusterSize
);
...
@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
BlockwiseWelford
::
Run
(
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
inv_std_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
});
// step2: normalization
// step2: save mean and inverse std for backward (optional)
if
(
block_k_cluster_id
==
0
&&
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
// step3: normalization
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
for
(
index_t
k
=
0
;
k
<
num_k_block_tile_iteration
;
++
k
)
for
(
index_t
k
=
0
;
k
<
num_k_block_tile_iteration
;
++
k
)
...
@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
});
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
constexpr
auto
offset_m_k
=
...
@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
...
@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
// normalize
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
View file @
1b5af83d
...
@@ -16,9 +16,11 @@ template <typename XDataType,
...
@@ -16,9 +16,11 @@ template <typename XDataType,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
KThreadClusterSize
,
...
@@ -32,6 +34,7 @@ template <typename XDataType,
...
@@ -32,6 +34,7 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
SweepOnce
>
bool
SweepOnce
>
struct
GridwiseNormalizationWelfordVariance_mk_to_mk
struct
GridwiseNormalizationWelfordVariance_mk_to_mk
{
{
...
@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
...
@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
using
ThreadReduceDstDesc_M
=
using
ThreadReduceDstDesc_M
=
...
@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
ThreadClusterLengths_M_K
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
>
;
ThreadClusterArrangeOrder
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M
&
save_mean_grid_desc_m
,
const
GridDesc_M
&
save_inv_std_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
const
YElementwiseOperation
y_elementwise_op
)
{
{
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
x_thread_buf
=
generate_tuple
(
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
...
@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
mean_thread_buf
;
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
var_thread_buf
;
auto
&
inv_std_thread_buf
=
var_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
...
@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_k_cluster_id
*
YDstVectorSize
),
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
y_elementwise_op
);
auto
threadwise_mean_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_mean_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_inv_std_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_inv_std_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
...
@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_beta_global
,
beta_grid_desc_m_k
.
GetElementSpaceSize
());
p_beta_global
,
beta_grid_desc_m_k
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
save_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_mean_global
,
save_mean_grid_desc_m
.
GetElementSpaceSize
());
auto
save_inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_inv_std_global
,
save_inv_std_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
GetKPerThread
(
x_grid_desc_m_k
,
thread_k_cluster_id
);
threadwise_welford
.
max_count_
=
GetKPerThread
(
x_grid_desc_m_k
,
thread_k_cluster_id
);
...
@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int
count
=
threadwise_welford
.
cur_count_
;
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
inv_std_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
});
// save mean and inverse std for backward (optional)
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
// normalization
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
constexpr
auto
offset_m_k
=
...
@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma & beta
// gamma & beta
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int
count
=
threadwise_welford
.
cur_count_
;
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
inv_std_thread_buf
(
I
)
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
});
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
auto
thread_copy_tail_m_k
=
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
...
@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
constexpr
auto
offset_m_k
=
...
@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
1b5af83d
...
@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
...
@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
}
};
};
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
{
...
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
...
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
>
{
{
...
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
...
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
intrin_mfma_f32_16x16x32bf8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32bf8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8bf8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8bf8
>
{
{
...
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
...
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
intrin_mfma_f32_16x16x32f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8f8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8f8
>
{
{
...
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
...
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
intrin_mfma_f32_16x16x32bf8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32bf8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
template
<
typename
base_type
,
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
MPerXdlops
,
...
@@ -792,7 +784,6 @@ struct MfmaSelector
...
@@ -792,7 +784,6 @@ struct MfmaSelector
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
{
...
@@ -804,9 +795,7 @@ struct MfmaSelector
...
@@ -804,9 +795,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
{
...
@@ -818,9 +807,7 @@ struct MfmaSelector
...
@@ -818,9 +807,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
{
...
@@ -832,9 +819,7 @@ struct MfmaSelector
...
@@ -832,9 +819,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
{
...
@@ -846,7 +831,6 @@ struct MfmaSelector
...
@@ -846,7 +831,6 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
}
#endif
static
constexpr
auto
selected_mfma
=
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
...
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
...
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
static_assert
(
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
||
#if defined CK_ENABLE_FP8
is_same
<
base_type
,
bf8_t
>::
value
||
||
is_same
<
base_type
,
f8_t
>::
value
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
#endif
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
),
#if defined CK_ENABLE_BF8
||
is_same
<
base_type
,
bf8_t
>::
value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
||
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
)
#endif
,
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"
);
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/utility/amd_xdlops.hpp
View file @
1b5af83d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#pragma once
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
...
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
}
};
};
#if defined CK_ENABLE_FP8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
struct
intrin_mfma_f32_32x32x16f8f8
;
...
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
...
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
#endif
}
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16bf8bf8
;
struct
intrin_mfma_f32_32x32x16bf8bf8
;
...
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
...
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif
#endif
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8bf8
;
struct
intrin_mfma_f32_32x32x16f8bf8
;
...
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
...
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif
#endif
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16bf8f8
;
struct
intrin_mfma_f32_32x32x16bf8f8
;
...
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
...
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif
#endif
}
}
};
};
#endif
}
// namespace ck
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
1b5af83d
...
@@ -9,15 +9,9 @@ namespace ck {
...
@@ -9,15 +9,9 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
#endif
using
bf8_t
=
unsigned
_BitInt
(
8
);
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_t
>
{
{
using
type
=
f8_t
;
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_t
>
{
{
using
type
=
bf8_t
;
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
>
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_t
>
{
{
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_t
>
{
{
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericUtils
struct
NumericUtils
...
@@ -1093,6 +1075,7 @@ struct NumericUtils<float>
...
@@ -1093,6 +1075,7 @@ struct NumericUtils<float>
{
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
bias
=
127
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
...
@@ -1109,6 +1092,7 @@ struct NumericUtils<half_t>
...
@@ -1109,6 +1092,7 @@ struct NumericUtils<half_t>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
bias
=
15
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
...
@@ -1120,22 +1104,21 @@ struct NumericUtils<half_t>
...
@@ -1120,22 +1104,21 @@ struct NumericUtils<half_t>
using
bitwise_type
=
uint16_t
;
using
bitwise_type
=
uint16_t
;
};
};
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericUtils
<
f8_t
>
struct
NumericUtils
<
f8_t
>
{
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
8
;
// negative zero nan mode
// static constexpr int bias = 7; // ieee mode
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericUtils
<
bf8_t
>
struct
NumericUtils
<
bf8_t
>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
};
#endif
}
// namespace ck
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
1b5af83d
...
@@ -5,9 +5,6 @@
...
@@ -5,9 +5,6 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
namespace
ck
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -19,6 +16,9 @@ enum class f8_rounding_mode
...
@@ -19,6 +16,9 @@ enum class f8_rounding_mode
stochastic
stochastic
};
};
__host__
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
__device__
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
}
// namespace ck
}
// namespace ck
namespace
ck
::
utils
{
namespace
ck
::
utils
{
...
@@ -36,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -36,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
int
exponent
;
int
exponent
,
bias
;
uint32_t
head
,
mantissa
,
sign
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
// nan code is same for float and half
constexpr
Y
nan_code
=
0x80
;
constexpr
Y
nan_code
=
0x80
;
...
@@ -51,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -51,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
sign
=
head
>>
(
in_exp
+
in_mant
);
bias
=
NumericUtils
<
X
>::
bias
;
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
in_exp
-
1
))
-
(
1
<<
(
out_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
...
@@ -69,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -69,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
// if input is half and output is bf8
if
((
NumericUtils
<
X
>::
mant
==
10
)
&&
(
NumericUtils
<
Y
>::
mant
==
2
)
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
in_mant
);
}
// check if x is 0.0
// check if x is 0.0
if
(
x_bitwise
==
0
)
if
(
x_bitwise
==
0
)
return
0
;
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
// First need to check if it is normal or denorm as there is a difference of implict 1
if
(
exponent
<=
0
)
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
+
1
-
exponent
))
-
1
;
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
mantissa
+=
1
<<
in_mant
;
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// apply random number if needed
// exponent and mantissa again3
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
in_mant
))
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
{
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
mantissa
>>=
1
;
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
exponent
++
;
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
in_mant
);
// Add the implicit 1 into mantissa
}
}
mantissa
>>=
(
in_mant
-
out_mant
);
// check negative exponent
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
if
(
exponent
<=
0
)
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
bool
odd
=
mantissa
&
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
out_exponent
==
0
)
{
{
if
(
x_bitwise
==
0
)
if
((
1
<<
in_mant
)
&
mantissa
)
return
0
;
else
{
{
// subnormal range; represented by a subnormal float8 (exponent 0)
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// and involves loss of accuracy
// No need to make 1 implicit now as it will be addressed later
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
}
}
}
// above range: quantize to maximum possible float of the same sign
else
else
if
(
exponent
>
max_exp
)
{
if
((
1
<<
(
in_mant
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa
>>=
(
in_mant
-
out_mant
);
if
(
out_exponent
>
max_exp
)
{
{
if
(
clip
)
if
(
clip
)
{
{
mantissa
=
(
1
<<
out_mant
)
-
1
;
mantissa
=
(
1
<<
out_mant
)
-
1
;
exponent
=
max_exp
;
out_
exponent
=
max_exp
;
}
}
else
else
{
{
...
@@ -127,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -127,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
}
}
// check if x is 0.0 or -0.0
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
if
(
out_
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
mantissa
&=
(
1
<<
out_mant
)
-
1
;
mantissa
&=
(
1
<<
out_mant
)
-
1
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
exponent
<<
out_mant
)
|
mantissa
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
out_
exponent
<<
out_mant
)
|
mantissa
;
}
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
...
@@ -196,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
...
@@ -196,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent
++
;
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in_mant
);
while
(
mantissa
<
(
1
<<
in_mant
))
mantissa
<<=
sh
;
{
exponent
+=
1
-
sh
;
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
((
1
<<
in_mant
)
-
1
);
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
...
@@ -244,5 +291,3 @@ __host__ __device__ Y cast_from_f8(X x)
...
@@ -244,5 +291,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
}
}
// namespace ck::utils
}
// namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment