gaoqiong / composable_kernel · Commits

Commit 1b5af83d
Authored Oct 20, 2023 by illsilin

    Merge branch 'develop' into lwpck-976

Parents: aac26d32, 1fd27d52
Changes: 176
Showing 20 changed files with 1873 additions and 369 deletions (+1873, -369)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp                        +86   -16
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp   +1    -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp                       +3    -28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp            +877  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp             +7    -28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp        +1    -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp                               +8    -1
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp                               +98   -23
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp                        +174  -84
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                                +103  -12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp                         +27   -14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                                    +18   -9
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp            +112  -5
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp                  +50   -16
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp                +85   -4
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp          +110  -7
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                                    +4    -28
include/ck/utility/amd_xdlops.hpp                                                                       +2    -13
include/ck/utility/data_type.hpp                                                                        +9    -26
include/ck/utility/f8_utils.hpp                                                                         +98   -53
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp

@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                        PipelineVer,
                                                        ComputeType>;
 
-    using Argument = typename GridwiseGemm::Argument;
+    struct Argument : public GridwiseGemm::Argument
+    {
+        Argument(const ADataType* p_a_grid_,
+                 const BDataType* p_b_grid_,
+                 CDataType* p_c_grid_,
+                 index_t M_, index_t N_, index_t K_,
+                 index_t StrideA_, index_t StrideB_, index_t StrideC_,
+                 index_t MPadded_, index_t NPadded_, index_t KPadded_,
+                 index_t K0_, index_t k_batch_,
+                 AElementwiseOperation a_element_op_,
+                 BElementwiseOperation b_element_op_,
+                 CElementwiseOperation c_element_op_)
+            : GridwiseGemm::Argument(p_a_grid_, p_b_grid_, p_c_grid_,
+                                     M_, N_, K_,
+                                     StrideA_, StrideB_, StrideC_,
+                                     MPadded_, NPadded_, KPadded_,
+                                     K0_, k_batch_),
+              a_element_op(a_element_op_),
+              b_element_op(b_element_op_),
+              c_element_op(c_element_op_)
+        {
+        }
+
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        CElementwiseOperation c_element_op;
+    };
 
     using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
 
     // Invoker

@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                           karg.M * karg.N * sizeof(CDataType),
                                           stream_config.stream_id_));
 
-                ave_time = launch_and_time_kernel(
-                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(gdx, gdy, gdz),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  static_cast<typename GridwiseGemm::Argument>(karg),
+                                                  b2c_map,
+                                                  karg.a_element_op,
+                                                  karg.b_element_op,
+                                                  karg.c_element_op);
             };
 
             if(has_main_k0_block_loop)

@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                     kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                          true,
                                                          InMemoryDataOperationEnum::Set,
-                                                         DefaultBlock2CTileMap>;
+                                                         DefaultBlock2CTileMap,
+                                                         AElementwiseOperation,
+                                                         BElementwiseOperation,
+                                                         CElementwiseOperation>;
                 Run(kernel);
             }

@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                     kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                          true,
                                                          InMemoryDataOperationEnum::AtomicAdd,
-                                                         DefaultBlock2CTileMap>;
+                                                         DefaultBlock2CTileMap,
+                                                         AElementwiseOperation,
+                                                         BElementwiseOperation,
+                                                         CElementwiseOperation>;
                 Run(kernel);
             }

@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                     kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                          false,
                                                          InMemoryDataOperationEnum::Set,
-                                                         DefaultBlock2CTileMap>;
+                                                         DefaultBlock2CTileMap,
+                                                         AElementwiseOperation,
+                                                         BElementwiseOperation,
+                                                         CElementwiseOperation>;
                 Run(kernel);
             }

@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                     kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                          false,
                                                          InMemoryDataOperationEnum::AtomicAdd,
-                                                         DefaultBlock2CTileMap>;
+                                                         DefaultBlock2CTileMap,
+                                                         AElementwiseOperation,
+                                                         BElementwiseOperation,
+                                                         CElementwiseOperation>;
                 Run(kernel);
             }

@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
-                             AElementwiseOperation,
-                             BElementwiseOperation,
-                             CElementwiseOperation,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op,
                              index_t KBatch)
     {
-        return Argument{p_a,
+        return Argument(p_a,
                         p_b,
                         p_c,
                         M,

@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         GridwiseGemm::CalculateNPadded(N),
                         GridwiseGemm::CalculateKPadded(K, KBatch),
                         GridwiseGemm::CalculateK0(K, KBatch),
-                        KBatch};
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op);
     }
 
     static auto MakeInvoker() { return Invoker{}; }

@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         index_t StrideA,
                         index_t StrideB,
                         index_t StrideC,
-                        AElementwiseOperation,
-                        BElementwiseOperation,
-                        CElementwiseOperation,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
                         ck::index_t KBatch = 1) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),

@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                           GridwiseGemm::CalculateNPadded(N),
                                           GridwiseGemm::CalculateKPadded(K, KBatch),
                                           GridwiseGemm::CalculateK0(K, KBatch),
-                                          KBatch);
+                                          KBatch,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
     }
 
     // polymorphic
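The hunks above extend the split-K GEMM's Argument so it also carries the A/B/C element-wise operators, which the launcher then forwards alongside the base gridwise argument. The following is an illustrative sketch only, not part of the commit and not CK API: it shows the same wrapping pattern with simplified, hypothetical types (GridArg, DeviceArg, PassThrough are invented names for the sketch).

#include <cstdint>

struct PassThrough { float operator()(float x) const { return x; } };

struct GridArg // stands in for the gridwise-level argument (GridwiseGemm::Argument in the hunk)
{
    const float* p_a; const float* p_b; float* p_c;
    std::int32_t M, N, K, k_batch;
};

template <typename AOp, typename BOp, typename COp>
struct DeviceArg : public GridArg // device-level argument additionally carries the element-wise ops
{
    DeviceArg(GridArg base, AOp a, BOp b, COp c) : GridArg(base), a_op(a), b_op(b), c_op(c) {}
    AOp a_op; BOp b_op; COp c_op;
};

// the launcher slices the argument back to the gridwise part and passes the ops separately,
// mirroring the shape of the launch_and_time_kernel call in the hunk above
template <typename AOp, typename BOp, typename COp>
void launch(const DeviceArg<AOp, BOp, COp>& arg)
{
    const GridArg& karg = static_cast<const GridArg&>(arg);
    (void)karg; (void)arg.a_op; (void)arg.b_op; (void)arg.c_op; // a real kernel would consume these
}

int main()
{
    DeviceArg<PassThrough, PassThrough, PassThrough> arg({nullptr, nullptr, nullptr, 128, 128, 64, 1},
                                                         {}, {}, {});
    launch(arg);
}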
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp

@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;
 
-                const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
+                const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                     GridwiseGemm,
                     ADataType,
                     BDataType,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp

@@ -12,6 +12,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

@@ -22,32 +23,6 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-namespace {
-struct ComputePtrOffsetOfStridedBatch
-{
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideC_);
-    }
-
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
-    index_t BatchStrideC_;
-};
-} // namespace
 
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,

@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
         Block2CTileMap block_2_ctile_map_;
 
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
 
         // element-wise op
         OutElementwiseOperation a_element_op_;

@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                     remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
                     remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DeviceOp::Block2CTileMap>,
-                    ComputePtrOffsetOfStridedBatch,
+                    ComputePtrOffsetOfStridedBatch<I0>,
                     has_main_loop,
                     has_double_loop>;
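This hunk drops the file-local anonymous-namespace batch-offset helper in favour of the shared, NumDTensor-parameterised one from device_grouped_conv_utils.hpp. The sketch below is illustrative only and not the commit's code: a minimal, hypothetical stand-in for such a shared helper, with the same field and accessor names as the deleted struct; index_t/long_index_t aliases and the struct name are assumptions for the sketch.

#include <array>
#include <cstddef>
#include <cstdint>

using index_t      = std::int32_t;
using long_index_t = std::int64_t;

template <std::size_t NumDTensor>
struct ComputePtrOffsetOfStridedBatchSketch
{
    long_index_t GetAPtrOffset(index_t g_idx) const { return g_idx * static_cast<long_index_t>(BatchStrideA_); }
    long_index_t GetBPtrOffset(index_t g_idx) const { return g_idx * static_cast<long_index_t>(BatchStrideB_); }
    long_index_t GetEPtrOffset(index_t g_idx) const { return g_idx * static_cast<long_index_t>(BatchStrideE_); }

    index_t BatchStrideA_ = 0;
    index_t BatchStrideB_ = 0;
    std::array<index_t, NumDTensor> BatchStrideDs_{}; // empty when there are no extra D tensors
    index_t BatchStrideE_ = 0;
};

// kernels without extra D tensors instantiate it with NumDTensor == 0, which corresponds to the
// ComputePtrOffsetOfStridedBatch<I0> usage in the hunks above
int main()
{
    ComputePtrOffsetOfStridedBatchSketch<0> offsets;
    offsets.BatchStrideA_ = 1024;
    return offsets.GetAPtrOffset(3) == 3 * 1024 ? 0 : 1;
}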
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
new file (0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <numeric>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t NDimSpatial,
          typename InLayout, typename WeiLayout, typename OutLayout,
          typename InDataType, typename WeiDataType, typename OutDataType, typename AccDataType,
          typename InElementwiseOperation, typename WeiElementwiseOperation, typename OutElementwiseOperation,
          ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
          index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t K0PerBlock, index_t K1,
          index_t MPerWMMA, index_t NPerWMMA, index_t MRepeat, index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          index_t NumGemmKPrefetchStage   = 1,
          LoopScheduler LoopSched         = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1,
          typename ck::enable_if<NDimSpatial == 3, bool>::type = false>
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
    : public DeviceGroupedConvBwdWeight<NDimSpatial,
                                        InLayout, WeiLayout, OutLayout,
                                        InDataType, WeiDataType, OutDataType,
                                        InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
    using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle;

    using ADataType = OutDataType;
    using BDataType = InDataType;
    using CDataType = WeiDataType;

    using AElementwiseOperation = OutElementwiseOperation;
    using BElementwiseOperation = InElementwiseOperation;
    using CElementwiseOperation = WeiElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    // 3d
    static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
        is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
        is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
    static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
        is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
        is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto GemmK1Number = Number<K1>{};
    static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number;

    template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    constexpr static auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho,
                                             const index_t Wo, const index_t K,
                                             const std::array<index_t, NDimSpatial + 3>& output_strides)
    {
        const index_t WoStride = output_strides[5];
        const auto KStride     = Number<1>{};
        return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
                                            make_tuple(WoStride, KStride));
    }

    template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    constexpr static auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi,
                                            const index_t Wi, const index_t C,
                                            const std::array<index_t, NDimSpatial + 3>& input_strides)
    {
        const index_t NStride  = input_strides[1];
        const index_t DiStride = input_strides[3];
        const index_t HiStride = input_strides[4];
        const index_t WiStride = input_strides[5];
        const auto CStride     = input_strides[2];

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
                                                make_tuple(WiStride, CStride));
        }
        else
        {
            return make_naive_tensor_descriptor(
                make_tuple(N, Di, Hi, Wi, C),
                make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
        }
    }

    template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    constexpr static auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y,
                                             const index_t X, const index_t C,
                                             const std::array<index_t, NDimSpatial + 3>& weights_strides)
    {
        const auto CStride = Number<1>{};
        const auto KStride = weights_strides[1];
        return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), make_tuple(KStride, CStride));
    }

    template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
        const index_t N, const index_t K, const index_t C,
        const std::array<index_t, NDimSpatial>& input_spatial_lengths,
        const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
        const std::array<index_t, NDimSpatial>& output_spatial_lengths,
        const std::array<index_t, NDimSpatial + 3>& input_strides,
        const std::array<index_t, NDimSpatial + 3>& weights_strides,
        const std::array<index_t, NDimSpatial + 3>& output_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& input_right_pads)
    {
        using namespace ck;

        const index_t Di = input_spatial_lengths[0];
        const index_t Hi = input_spatial_lengths[1];
        const index_t Wi = input_spatial_lengths[2];

        const index_t Do = output_spatial_lengths[0];
        const index_t Ho = output_spatial_lengths[1];
        const index_t Wo = output_spatial_lengths[2];

        const index_t Z = filter_spatial_lengths[0];
        const index_t Y = filter_spatial_lengths[1];
        const index_t X = filter_spatial_lengths[2];

        const index_t ConvStrideD = conv_filter_strides[0];
        const index_t ConvStrideH = conv_filter_strides[1];
        const index_t ConvStrideW = conv_filter_strides[2];

        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];

        const index_t InLeftPadD = input_left_pads[0];
        const index_t InLeftPadH = input_left_pads[1];
        const index_t InLeftPadW = input_left_pads[2];

        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];

        const index_t GemmKTotal = N * Do * Ho * Wo;
        const index_t GemmM      = K;
        const index_t GemmN      = C * Z * X * Y;

        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;

        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
        const index_t GemmKPad = GemmK0 * GemmK1Number;

        const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
        const auto in_grid_desc  = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
        const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_grid_desc);
        }
        else
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_dip_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                           make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));

            const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_gemmktotal_gemmn_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = transform_tensor_descriptor(
                out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                make_tuple(make_pass_through_transform(GemmK0),
                           make_right_pad_transform(GemmM, PadGemmM),
                           make_pass_through_transform(GemmK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = transform_tensor_descriptor(
                in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                make_tuple(make_pass_through_transform(GemmK0),
                           make_right_pad_transform(GemmN, PadGemmN),
                           make_pass_through_transform(GemmK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(
                wei_grid_desc,
                make_tuple(make_right_pad_transform(GemmM, PadGemmM),
                           make_right_pad_transform(GemmN, PadGemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
                              wei_gemmm_gemmn_pad_grid_desc);
        }
    }

    template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto GetABCGridDesc()
    {
        const index_t dim = 1;
        const std::array<index_t, NDimSpatial> lengths{1, 1, 1};
        const std::array<index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
        const std::array<index_t, NDimSpatial> params{1, 1, 1};
        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
            dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
            params, params, params, params);
    }

    using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
        // DataType Family
        ADataType, BDataType, AccDataType, CDataType, Tuple<>, CDataType,
        // InMemory Data Descriptor
        AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, Tuple<>, CGridDesc_M_N,
        // ElementwiseOp Family
        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        // Tiling Family
        MPerBlock, NPerBlock, K0PerBlock, MPerWMMA, NPerWMMA, K1, MRepeat, NRepeat,
        // ThreadCluster Family
        BlockSize,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        NumGemmKPrefetchStage,
        LoopSched,
        PipelineVer>;

    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{}));
    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
    using Block2CTileMap =
        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));

    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_in_grid,
                 WeiDataType* p_wei_grid,
                 const OutDataType* p_out_grid,
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                 const std::array<index_t, NDimSpatial>& input_left_pads,
                 const std::array<index_t, NDimSpatial>& input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op,
                 index_t split_k)
            : p_a_grid_{p_out_grid},
              p_b_grid_{p_in_grid},
              p_c_grid_{p_wei_grid},
              a_grid_desc_kbatch_k0_m_k1_{},
              b_grid_desc_kbatch_k0_n_k1_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{},
              compute_ptr_offset_of_batch_{},
              a_element_op_{out_element_op},
              b_element_op_{in_element_op},
              c_element_op_{wei_element_op},
              Conv_G_{a_g_n_c_wis_lengths[0]},
              Conv_N_{a_g_n_c_wis_lengths[1]},
              Conv_K_{b_g_k_c_xs_lengths[1]},
              Conv_C_{a_g_n_c_wis_lengths[2]},
              input_spatial_lengths_{},
              filter_spatial_lengths_{},
              output_spatial_lengths_{},
              conv_filter_strides_{conv_filter_strides},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads},
              k_batch_{split_k}
        {
            constexpr index_t spatial_offset = 3;
            std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
                      end(a_g_n_c_wis_lengths),
                      begin(input_spatial_lengths_));
            std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
                      end(b_g_k_c_xs_lengths),
                      begin(filter_spatial_lengths_));
            std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
                      end(e_g_n_k_wos_lengths),
                      begin(output_spatial_lengths_));

            const auto descs =
                DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
                    Conv_N_, Conv_K_, Conv_C_,
                    input_spatial_lengths_, filter_spatial_lengths_, output_spatial_lengths_,
                    a_g_n_c_wis_strides, b_g_k_c_xs_strides, e_g_n_k_wos_strides,
                    conv_filter_strides, conv_filter_dilations,
                    input_left_pads, input_right_pads);

            a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
            b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
            c_grid_desc_m_n_            = descs[I2];

            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */);

            // A/B/C Batch Stride
            compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
            compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
            compute_ptr_offset_of_batch_.BatchStrideE_ =
                Conv_K_ * Conv_C_ *
                std::accumulate(begin(filter_spatial_lengths_),
                                end(filter_spatial_lengths_),
                                index_t{1},
                                std::multiplies<>{});

            if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
                                           b_grid_desc_kbatch_k0_n_k1_,
                                           c_grid_desc_m_n_,
                                           block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
            }
        }

        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;

        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
        Block2CTileMap block_2_ctile_map_;

        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;

        OutElementwiseOperation a_element_op_;
        InElementwiseOperation b_element_op_;
        WeiElementwiseOperation c_element_op_;

        // for checking IsSupportedArgument()
        const index_t Conv_G_;
        const index_t Conv_N_;
        const index_t Conv_K_;
        const index_t Conv_C_;

        std::array<index_t, NDimSpatial> input_spatial_lengths_;
        std::array<index_t, NDimSpatial> filter_spatial_lengths_;
        std::array<index_t, NDimSpatial> output_spatial_lengths_;
        const std::array<index_t, NDimSpatial>& conv_filter_strides_;
        const std::array<index_t, NDimSpatial>& input_left_pads_;
        const std::array<index_t, NDimSpatial>& input_right_pads_;

        const index_t k_batch_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        void Print(const Argument& arg)
        {
            std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
        }

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                Print(arg);
            }

            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                            arg.b_grid_desc_kbatch_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
                    "setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;

            const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);

            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    typename GridwiseGemm::DsGridPointer,
                    CDataType,
                    OutElementwiseOperation,
                    InElementwiseOperation,
                    WeiElementwiseOperation,
                    AGridDesc_K0_M_K1,
                    BGridDesc_K0_N_K1,
                    DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    ComputePtrOffsetOfStridedBatch<I0>,
                    has_main_loop>;

                using EmptyTuple = Tuple<>;
                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              EmptyTuple{}, // Ds
                                              arg.p_c_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.Conv_G_,
                                              arg.a_grid_desc_kbatch_k0_m_k1_,
                                              arg.b_grid_desc_kbatch_k0_n_k1_,
                                              DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{},
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_ctile_map_,
                                              arg.compute_ptr_offset_of_batch_);
            };

            if(has_main_k0_block_loop)
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        // check device
        if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
           get_device_name() == "gfx1102")
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // TODO: Add support for split_k > 1
        if(arg.k_batch_ != 1)
        {
            return false;
        }

        if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
        {
            return false;
        }

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's a 1x1 convolution with stride=1 and no padding
            for(int i = 0; i < NDimSpatial; i++)
            {
                if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
                     arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
                {
                    return false;
                }
            }
        }

        // vector load A/B matrix from global memory
        if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 &&
             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
        {
            return false;
        }

        // vector store C matrix into global memory
        if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            return false;
        }

        // Gridwise GEMM size
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                           arg.b_grid_desc_kbatch_k0_n_k1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const InDataType* p_in_grid,
                             WeiDataType* p_wei_grid,
                             const OutDataType* p_out_grid,
                             const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
                             const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                             const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
                             const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                             const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
                             const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                             const std::array<index_t, NDimSpatial>& conv_filter_strides,
                             const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                             const std::array<index_t, NDimSpatial>& input_left_pads,
                             const std::array<index_t, NDimSpatial>& input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op,
                             const index_t split_k)
    {
        return Argument{p_in_grid,
                        p_wei_grid,
                        p_out_grid,
                        a_g_n_c_wis_lengths, // input
                        a_g_n_c_wis_strides,
                        b_g_k_c_xs_lengths, // weight
                        b_g_k_c_xs_strides,
                        e_g_n_k_wos_lengths, // output
                        e_g_n_k_wos_strides,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op,
                        split_k};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_grid,
                        void* p_wei_grid,
                        const void* p_out_grid,
                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op,
                        const index_t split_k) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                          static_cast<WeiDataType*>(p_wei_grid),
                                          static_cast<const OutDataType*>(p_out_grid),
                                          a_g_n_c_wis_lengths, // input
                                          a_g_n_c_wis_strides,
                                          b_g_k_c_xs_lengths, // weight
                                          b_g_k_c_xs_strides,
                                          e_g_n_k_wos_lengths, // output
                                          e_g_n_k_wos_strides,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          in_element_op,
                                          wei_element_op,
                                          out_element_op,
                                          split_k);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
            << K1 << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << ABlockTransferDstScalarPerVector_K1 << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << BBlockTransferDstScalarPerVector_K1 << ", "
            << CShuffleMRepeatPerShuffle << ", "
            << CShuffleNRepeatPerShuffle << ", "
            << CShuffleBlockTransferScalarPerVector_NPerBlock
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
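The new device op above maps 3-D grouped-conv backward weight onto a GEMM, with GemmM = K, GemmN = Z*Y*X*C and the reduced dimension GemmKTotal = N*Do*Ho*Wo, as its grid descriptors encode. The following is an illustrative sketch only, not part of the commit: a naive single-group CPU reference for that GEMM view, with small made-up sizes and unit stride/dilation/zero padding assumed for brevity.

#include <vector>

int main()
{
    const int N = 1, K = 2, C = 3;
    const int Di = 4, Hi = 4, Wi = 4; // input spatial lengths
    const int Z = 2, Y = 2, X = 2;    // filter spatial lengths
    const int stride = 1, dilation = 1, pad = 0;
    const int Do = (Di + 2 * pad - dilation * (Z - 1) - 1) / stride + 1;
    const int Ho = (Hi + 2 * pad - dilation * (Y - 1) - 1) / stride + 1;
    const int Wo = (Wi + 2 * pad - dilation * (X - 1) - 1) / stride + 1;

    std::vector<float> out(N * Do * Ho * Wo * K, 1.f); // dOut, NDHWK layout
    std::vector<float> in(N * Di * Hi * Wi * C, 1.f);  // input, NDHWC layout
    std::vector<float> wei(K * Z * Y * X * C, 0.f);    // dWeight, KZYXC layout

    // dW[k][z][y][x][c] = sum over (n, do, ho, wo) of dOut[n][do][ho][wo][k] * In[n][di][hi][wi][c]
    for(int k = 0; k < K; ++k)
        for(int z = 0; z < Z; ++z)
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                    for(int c = 0; c < C; ++c)
                    {
                        float acc = 0.f;
                        for(int n = 0; n < N; ++n)
                            for(int d = 0; d < Do; ++d)
                                for(int h = 0; h < Ho; ++h)
                                    for(int w = 0; w < Wo; ++w)
                                    {
                                        const int di = d * stride + z * dilation - pad;
                                        const int hi = h * stride + y * dilation - pad;
                                        const int wi = w * stride + x * dilation - pad;
                                        acc += out[(((n * Do + d) * Ho + h) * Wo + w) * K + k] *
                                               in[(((n * Di + di) * Hi + hi) * Wi + wi) * C + c];
                                    }
                        wei[(((k * Z + z) * Y + y) * X + x) * C + c] = acc;
                    }

    // with all-ones tensors every weight gradient equals the reduction length N*Do*Ho*Wo
    return wei[0] == static_cast<float>(N * Do * Ho * Wo) ? 0 : 1;
}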
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp

@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -21,32 +22,6 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-namespace {
-struct ComputePtrOffsetOfStridedBatch
-{
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideC_);
-    }
-
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
-    index_t BatchStrideC_;
-};
-} // namespace
 
 template <typename GridwiseGemm,
           typename FloatA,
           typename FloatB,

@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         Block2CTileMap block_2_ctile_map_;
 
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
 
         index_t M01_;
         index_t N01_;

@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                 remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                 remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                 remove_reference_t<DeviceOp::Block2CTileMap>,
-                ComputePtrOffsetOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch<I0>,
                 has_main_loop>;
 
             return launch_and_time_kernel(stream_config,

@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
         if constexpr(NDimSpatial == 1)
         {
             if constexpr(!is_GNWK_GKXC_GNWC)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop) {
             constexpr bool has_main_loop = has_main_k_block_loop.value;
 
-            const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
+            const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                 GridwiseOp,
                 ADataType,
                 BDataType,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp

@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
         return ds_offset;
     }
 
-    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
     {
         return g_idx * static_cast<long_index_t>(BatchStrideE_);
     }
 
+    // alias for kernels without multiple D
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+    }

@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
     index_t BatchStrideB_;
     Array<ck::index_t, NumDTensor> BatchStrideDs_;
     index_t BatchStrideE_;
+    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };
 
 } // namespace device
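The added BatchStrideC_ reference member and GetCPtrOffset accessor keep kernels that still use the C naming working while the storage itself stays named BatchStrideE_. The sketch below is illustrative only, a minimal, hypothetical stand-in showing the reference-member alias idiom; note that a reference member makes the struct non-copy-assignable by default, which is a deliberate trade-off of this approach.

#include <cstdint>

struct BatchStrides
{
    std::int32_t BatchStrideE_  = 0;
    std::int32_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};

int main()
{
    BatchStrides s;
    s.BatchStrideC_ = 42;                 // a write through the alias...
    return s.BatchStrideE_ == 42 ? 0 : 1; // ...is observed under the new name
}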
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp

@@ -28,6 +28,7 @@ template <typename XDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,

@@ -43,12 +44,13 @@ template <typename XDataType,
           index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize,
           bool UseWelford = true>
 struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
                                                             ComputeDataType,
                                                             YDataType,
                                                             SaveMeanInvStdDataType,
                                                             YElementwiseOperation,
                                                             Rank,
                                                             NumReduceDim>

@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                       (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
                   "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
 
+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
 
     using PassThrough = tensor_operation::element_wise::PassThrough;
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
 
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize  = KThreadClusterSize * KThreadSliceSize;
 
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+    static_assert(!reduceAllDim); // TODO
 
     static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                     const std::vector<index_t>& inStrides,
                                     int numBlockTileIteration)
     {
         constexpr index_t NumInvariantDim  = Rank - NumReduceDim;
         static constexpr index_t numSrcDim = Rank;
         static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         return (in_grid_desc_m_k_padded);
     };
 
+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
+                                               const std::vector<index_t>& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded =
+            transform_tensor_descriptor(grid_desc_m,
+                                        make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return grid_desc_m_padded;
+    }
 
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
+    using GridDesc_M   = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
 
     struct Argument : public BaseArgument
     {

@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                  const std::vector<index_t> gammaStrides,
                  const std::vector<index_t> betaStrides,
                  const std::vector<index_t> yStrides,
+                 const std::vector<index_t> saveMeanStrides,
+                 const std::vector<index_t> saveInvStdStrides,
                  const std::vector<index_t> reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
+              p_saveMean_(p_saveMean),
+              p_saveInvStd_(p_saveInvStd),
               y_elementwise_op_(y_elementwise_op)
         {
             epsilon_ = static_cast<ComputeDataType>(epsilon);

@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
             betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
+            saveMeanStrides_   = saveMeanStrides;
+            saveInvStdStrides_ = saveInvStdStrides;
 
-            long_index_t invariant_length;
-            long_index_t reduce_length;
-
-            std::tie(invariant_length, reduce_length) =
-                get_2d_lengths<Rank, NumReduceDim>(Lengths_);
+            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
 
-            numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
+            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
 
-            gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
+            gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
 
             x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
             gamma_grid_desc_m_k_ =

@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             beta_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
             y_grid_desc_m_k_    = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
+            save_mean_grid_desc_m_    = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
+            save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
 
+            isSweeponce_ =
+                x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
 
             if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length_ = 1;
             else
                 invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
         }
 
         ComputeDataType epsilon_;

@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
         YDataType* p_y_;
+        SaveMeanInvStdDataType* p_saveMean_;
+        SaveMeanInvStdDataType* p_saveInvStd_;
 
         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
+        std::vector<index_t> saveMeanStrides_;
+        std::vector<index_t> saveInvStdStrides_;
 
         YElementwiseOperation y_elementwise_op_;

@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         GridDesc_M_K gamma_grid_desc_m_k_;
         GridDesc_M_K beta_grid_desc_m_k_;
         GridDesc_M_K y_grid_desc_m_k_;
+        GridDesc_M save_mean_grid_desc_m_;
+        GridDesc_M save_inv_std_grid_desc_m_;
 
+        bool isSweeponce_;
+        index_t MRaw_; // invarient length
+        index_t KRaw_; // reduce length
 
         index_t invariant_lowest_length_;
     };
 
     struct Invoker : public BaseInvoker

@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
+                                                     SaveMeanInvStdDataType,
                                                      ComputeDataType,
                                                      YElementwiseOperation,
                                                      GridDesc_M_K,
+                                                     GridDesc_M,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,

@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                      BetaSrcVectorSize,
                                                      XYSrcVectorDim,
                                                      YDstVectorSize,
+                                                     SaveMeanInvStdDstVectorSize,
                                                      UseWelford>(arg.isSweeponce_);
 
             float avg_time = 0;

@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                      arg.gamma_grid_desc_m_k_,
                                      arg.beta_grid_desc_m_k_,
                                      arg.y_grid_desc_m_k_,
+                                     arg.save_mean_grid_desc_m_,
+                                     arg.save_inv_std_grid_desc_m_,
                                      arg.numBlockTileIteration_,
                                      arg.epsilon_,
                                      arg.p_x_,
                                      arg.p_gamma_,
                                      arg.p_beta_,
                                      arg.p_y_,
+                                     arg.p_saveMean_,
+                                     arg.p_saveInvStd_,
                                      arg.y_elementwise_op_);
 
             return (avg_time);

@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
 
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
 
         if constexpr(XYSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)

@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             }
             else
             {
+                printf("!!!! %d \n", p_arg_->invariant_lowest_length_);
                 if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
                     return false;
             };
         }

@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                     return (false);
 
-                if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                     return (false);
             }
             else // if fastest dim is reduced

@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 return (false);
         }
 
+        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
+            return false;
 
         return true;
     };

@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                         const std::vector<index_t> gammaStrides,
                         const std::vector<index_t> betaStrides,
                         const std::vector<index_t> yStrides,
+                        const std::vector<index_t> saveMeanStrides,
+                        const std::vector<index_t> saveInvStdStrides,
                         const std::vector<index_t> reduceDims,
                         double epsilon,
                         const void* p_x,

@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                         const void* p_gamma,
                         const void* p_beta,
                         void* p_y,
                         void* p_saveMean,
-                        void* p_saveInvVar,
+                        void* p_saveInvStd,
                         YElementwiseOperation y_elementwise_op) override
     {
-        // TODO
-        // Optional cache of the intermediate results (mean and InvVariance) during the
-        // forward pass could speedup in the backward
-        ignore = p_saveMean;
-        ignore = p_saveInvVar;
 
         if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
            betaStrides.size() != Rank || yStrides.size() != Rank ||
           saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
             throw std::runtime_error("dimension is incorrect");
 
         return std::make_unique<Argument>(lengths,
                                           xStrides,
                                           gammaStrides,
                                           betaStrides,
                                           yStrides,
+                                          saveMeanStrides,
+                                          saveInvStdStrides,
                                           reduceDims,
                                           y_elementwise_op,
                                           epsilon,
                                           static_cast<const XDataType*>(p_x),
                                           static_cast<const GammaDataType*>(p_gamma),
                                           static_cast<const BetaDataType*>(p_beta),
-                                          static_cast<YDataType*>(p_y));
+                                          static_cast<YDataType*>(p_y),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveMean),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
     };
 
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
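The changes above let the normalization kernel also write out, per reduced row, the mean and inverse standard deviation through the new p_saveMean_/p_saveInvStd_ pointers and their padded M-only descriptors. The following is an illustrative sketch only, not CK code: plain C++ with made-up sizes showing what those two saved quantities are, namely the row mean and 1/sqrt(variance + epsilon) with the biased variance used for normalization.

#include <cmath>
#include <vector>

int main()
{
    const int M = 2, K = 4; // M invariant rows, K reduced elements per row
    const double eps = 1e-5;
    std::vector<double> x = {1, 2, 3, 4, 2, 2, 2, 2};
    std::vector<double> save_mean(M), save_inv_std(M);

    for(int m = 0; m < M; ++m)
    {
        double mean = 0, var = 0;
        for(int k = 0; k < K; ++k)
            mean += x[m * K + k];
        mean /= K;
        for(int k = 0; k < K; ++k)
        {
            const double d = x[m * K + k] - mean;
            var += d * d;
        }
        var /= K; // biased variance, as used for normalization

        save_mean[m]    = mean;
        save_inv_std[m] = 1.0 / std::sqrt(var + eps);
    }

    return (save_mean[0] == 2.5 && save_mean[1] == 2.0) ? 0 : 1;
}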
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp

@@ -19,7 +19,7 @@
 namespace ck {
 
 template <typename GridwiseWelford,
           typename XDataType,
-          typename MeanVarDataType,
+          typename WorkspaceMeanVarDataType,
           typename ComputeDataType,
           typename XGridDesc_M_K,
           typename MeanVarGridDesc_M_KBlock>

@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
                               const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x_global,
-                              MeanVarDataType* const __restrict__ p_welford_mean,
-                              MeanVarDataType* const __restrict__ p_welford_variance,
+                              WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
+                              WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
                               int32_t* const __restrict__ p_welford_count)
 {
     GridwiseWelford::Run(x_grid_desc_m_k,

@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
 };
 
 template <typename GridwiseWelfordNormalization,
-          typename MeanVarDataType,
+          typename WorkspaceMeanVarDataType,
           typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename ComputeDataType,
           typename YElementwiseOperation,
           typename MeanVarGridDesc_M_KBlock,
           typename CountGridDesc_M_KBlock,
-          typename XYGammaBetaGridDesc_M_K>
+          typename XYGammaBetaGridDesc_M_K,
+          typename SaveMeanInvStdGridDesc_M>
 __global__ void
 kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                               const CountGridDesc_M_KBlock count_grid_desc_m_kblock,

@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                               const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
+                              const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
+                              const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
                               ComputeDataType epsilon,
-                              const MeanVarDataType* const p_mean_global,
-                              const MeanVarDataType* const p_variance_global,
+                              const WorkspaceMeanVarDataType* const p_mean_global,
+                              const WorkspaceMeanVarDataType* const p_variance_global,
                               const int32_t* const p_welford_count_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
 {
     GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,

@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                                       gamma_grid_desc_m_k,
                                       beta_grid_desc_m_k,
                                       y_grid_desc_m_k,
+                                      save_mean_grid_desc_m,
+                                      save_inv_std_grid_desc_m,
                                       num_k_mean_var_count_iteration,
                                       num_k_block_tile_iteration,
                                       k_grid_size,

@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                                       p_gamma_global,
                                       p_beta_global,
                                       p_y_global,
+                                      p_save_mean_global,
+                                      p_save_inv_std_global,
                                       y_elementwise_op);
 };
 
 } // namespace ck

@@ -107,6 +117,7 @@ template <typename XDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,

@@ -121,17 +132,18 @@ template <typename XDataType,
           index_t GammaSrcVectorSize,
           index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
-          index_t YDstVectorSize>
+          index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize>
 struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
                                                                   GammaDataType,
                                                                   BetaDataType,
                                                                   ComputeDataType,
                                                                   YDataType,
                                                                   SaveMeanInvStdDataType,
                                                                   YElementwiseOperation,
                                                                   Rank,
                                                                   NumReduceDim>
 {
-    using MeanVarDataType = ComputeDataType;
+    using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
 
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
 
     static_assert(
@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static_assert
(
!
reduceAllDim
);
// TODO
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
kBlockSize
,
int
numBlockTileIteration
)
{
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
numSrcDim
=
Rank
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
...
...
@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
};
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeMeanVarDescriptor_M_K
(
index_t
M
,
index_t
K
)
static
auto
Make
Workspace
MeanVarDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
...
...
@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeCountDescriptor_M_K
(
index_t
M
,
index_t
K
)
static
auto
Make
Workspace
CountDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I0
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
}
static
auto
MakeSaveMeanInvStdDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
const
auto
tupleSrcLengths
=
make_tuple_from_array_and_index_seq
(
lengths
,
InvariantDims
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array_and_index_seq
(
strides
,
InvariantDims
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
grid_desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
InvariantDims
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
pad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
pad_M
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
grid_desc_m_padded
;
}
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
Kernel1MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
decltype
(
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
decltype
(
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2CountGridDesc_M_KBlock
=
decltype
(
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
decltype
(
MakeWorkspaceCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
SaveMeanInvStdGridDesc_M
=
decltype
(
MakeSaveMeanInvStdDescriptor_M
({
1
},
{
1
}));
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
ComputeDataType
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
,
BlockSize
,
...
...
@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
XSrcVectorSize
>
;
using
GridwiseWelfordNormalization
=
GridwiseNormalizationSplitK2nd
<
MeanVarDataType
,
GridwiseNormalizationSplitK2nd
<
Workspace
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
,
SaveMeanInvStdGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYVectorDim
,
YDstVectorSize
>
;
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
...
...
@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
YElementwiseOperation
y_elementwise_op
,
double
epsilon
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
)
YDataType
*
p_y
,
SaveMeanInvStdDataType
*
p_saveMean
,
SaveMeanInvStdDataType
*
p_saveInvStd
)
:
p_x_
(
p_x
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_y_
(
p_y
),
p_saveMean_
(
p_saveMean
),
p_saveInvStd_
(
p_saveInvStd
),
p_workspace_mean_
{
nullptr
},
p_workspace_var_
{
nullptr
},
p_workspace_count_
{
nullptr
},
...
...
@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
saveMeanStrides_
=
saveMeanStrides
;
saveInvStdStrides_
=
saveInvStdStrides
;
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
...
...
@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
save_mean_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveMeanStrides
);
save_inv_std_grid_desc_m_
=
MakeSaveMeanInvStdDescriptor_M
(
Lengths_
,
saveInvStdStrides
);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
kGridSize_
);
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
kGridSize_
);
kernel2_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
Make
Workspace
MeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
kernel2_count_grid_desc_m_kblock_
=
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
MakeWorkspaceCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
if
constexpr
(
NumInvariantDim
==
0
)
invariant_lowest_length_
=
1
;
else
invariant_lowest_length_
=
Lengths_
[
NumInvariantDim
-
1
];
}
ComputeDataType
epsilon_
;
...
...
@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
YDataType
*
p_y_
;
SaveMeanInvStdDataType
*
p_saveMean_
;
SaveMeanInvStdDataType
*
p_saveInvStd_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_var_
;
void
*
p_workspace_count_
;
...
...
@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
std
::
vector
<
index_t
>
saveMeanStrides_
;
std
::
vector
<
index_t
>
saveInvStdStrides_
;
YElementwiseOperation
y_elementwise_op_
;
...
...
@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K
gamma_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
SaveMeanInvStdGridDesc_M
save_mean_grid_desc_m_
;
SaveMeanInvStdGridDesc_M
save_inv_std_grid_desc_m_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
...
...
@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
};
struct
Invoker
:
public
BaseInvoker
...
...
@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
XDataType
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
ComputeDataType
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
>
;
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
MeanVarDataType
,
Workspace
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
>
;
SrcGridDesc_M_K
,
SaveMeanInvStdGridDesc_M
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel1
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
numMeanVarCountIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
kGridSize_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
y_elementwise_op_
);
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel1
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
save_mean_grid_desc_m_
,
arg
.
save_inv_std_grid_desc_m_
,
arg
.
numMeanVarCountIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
kGridSize_
,
arg
.
epsilon_
,
static_cast
<
const
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
const
WorkspaceMeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
const
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
p_saveMean_
,
arg
.
p_saveInvStd_
,
arg
.
y_elementwise_op_
);
return
avg_time
;
};
...
...
@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
// workspace for welford intermediate mean
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
workspace_size
+=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
workspace_size
+=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
kGridSize_
*
sizeof
(
int32_t
)
+
64
;
...
...
@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
// setup buffer used for intermediate welford mean
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
index_t
variance_space_sz
=
welford_size
*
sizeof
(
Workspace
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welford count
...
...
@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XYVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
...
...
@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
XSrcVectorSize
!=
0
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
YDstVectorSize
!=
0
)
return
false
;
};
}
...
...
@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
_
%
BetaSrcVectorSize
!=
0
)
return
false
;
}
else
// if fastest dim is reduced
...
...
@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
kGridSize_
<=
1
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length_
%
SaveMeanInvStdDstVectorSize
!=
0
)
return
false
;
return
true
;
};
...
...
@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
double
epsilon
,
const
void
*
p_x
,
...
...
@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveInv
Var
,
void
*
p_saveInv
Std
,
YElementwiseOperation
y_elementwise_op
)
override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore
=
p_saveMean
;
ignore
=
p_saveInvVar
;
if
(
lengths
.
size
()
!=
Rank
||
xStrides
.
size
()
!=
Rank
||
gammaStrides
.
size
()
!=
Rank
||
betaStrides
.
size
()
!=
Rank
||
yStrides
.
size
()
!=
Rank
||
saveMeanStrides
.
size
()
!=
NumInvariantDim
||
saveInvStdStrides
.
size
()
!=
NumInvariantDim
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
gammaStrides
,
betaStrides
,
yStrides
,
saveMeanStrides
,
saveInvStdStrides
,
reduceDims
,
y_elementwise_op
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
YDataType
*>
(
p_y
));
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveMean
),
static_cast
<
SaveMeanInvStdDataType
*>
(
p_saveInvStd
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
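The split-K flow above writes per-partition mean, variance and count to the workspace in kernel 1 and merges them in kernel 2 before normalizing. The merge performed by BlockwiseWelford::Run is the standard parallel combination of partial statistics; as a reference sketch (standard formula, the exact parameterization inside BlockwiseWelford is not shown in this hunk):

    n     = n_a + n_b
    delta = mean_b - mean_a
    mean  = mean_a + delta * n_b / n
    var   = (n_a * var_a + n_b * var_b) / n + delta * delta * n_a * n_b / (n * n)

with inv_std = 1 / sqrt(var + epsilon) computed afterwards and exposed through the optional save outputs.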
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @ 1b5af83d
...
@@ -113,7 +113,6 @@ struct PassThrough
    }
#endif
#if defined CK_ENABLE_FP8
-   template <>
    __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
    {
...
@@ -143,9 +142,7 @@ struct PassThrough
    {
        y = type_convert<f8_t>(x);
    }
#endif
#if defined CK_ENABLE_BF8
-   template <>
    __host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
    {
...
@@ -175,7 +172,6 @@ struct PassThrough
    {
        y = ck::type_convert<bf8_t>(x);
    }
#endif
};

struct UnaryConvert
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
    }
};

-#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
    // convert to fp8 using stochastic rounding (SR)
...
@@ -212,7 +207,8 @@ struct ConvertF8SR
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // check Y datatype
-       static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
+       static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
+                     "Data type is not supported by this operation!");

        // check X datatype
        static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
...
@@ -221,7 +217,6 @@ struct ConvertF8SR
        y = f8_convert_sr<Y>(x);
    }
};
-#endif

struct Scale
{
...
@@ -448,10 +443,11 @@ struct Sigmoid
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                         is_same<T, ck::half_t>::value,
+                         is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                         is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");
-       y = 1 / (ck::type_convert<T>(1) + exp(-x));
+       constexpr T one = type_convert<T>(1);
+       y               = one / (one + ck::math::exp(-x));
    };
};
...
@@ -461,7 +457,8 @@ struct TanH
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                         is_same<T, ck::half_t>::value,
+                         is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                         is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        y = ck::math::tanh(x);
...
@@ -487,7 +484,101 @@ struct Swish
        y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
    };

-   float beta_ = 1.0f;
+   const float beta_;
};

struct SoftRelu
{
    SoftRelu(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha  = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
    }
    const float alpha_;
};

struct Power
{
    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
        : alpha_(alpha), beta_(beta), gamma_(gamma){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha     = type_convert<T>(alpha_);
        T casted_beta      = type_convert<T>(beta_);
        T casted_gamma     = type_convert<T>(gamma_);
        T shifted_scaled_x = casted_alpha + casted_beta * x;
        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
    }
    const float alpha_;
    const float beta_;
    const float gamma_;
};

struct ClippedRelu
{
    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        T casted_beta  = type_convert<T>(beta_);
        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
    }
    const float alpha_;
    const float beta_;
};

struct LeakyRelu
{
    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        y              = x >= 0 ? x : x * casted_alpha;
    }
    const float alpha_;
};

struct Elu
{
    Elu(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
    }
    const float alpha_;
};

} // namespace element_wise
...
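The activation functors added above (SoftRelu, Power, ClippedRelu, LeakyRelu, Elu, plus the widened Sigmoid/TanH type support) each reduce to a one-line scalar formula. The snippet below restates them in plain host C++ with float only; it is a reference for the math, not the ck:: functors themselves (no type_convert, no half/int8 handling):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Scalar reference versions of the element-wise operators added in this commit.
    float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
    float soft_relu(float x, float alpha = 1.0f) { return std::log(1.0f + std::exp(x * alpha)) / alpha; }
    float power(float x, float alpha = 0.0f, float beta = 1.0f, float gamma = 2.0f)
    {
        return std::pow(alpha + beta * x, gamma);
    }
    float clipped_relu(float x, float alpha = 0.0f, float beta = 1.0f)
    {
        return std::min(beta, std::max(alpha, x));
    }
    float leaky_relu(float x, float alpha = 0.01f) { return x >= 0.0f ? x : x * alpha; }
    float elu(float x, float alpha = 1.0f) { return x > 0.0f ? x : alpha * std::expm1(x); }

    int main()
    {
        for(float x : {-2.0f, -0.5f, 0.0f, 1.5f})
            std::printf("x=%5.2f softrelu=%.4f elu=%.4f leaky=%.4f\n",
                        x, soft_relu(x), elu(x), leaky_relu(x));
        return 0;
    }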
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @ 1b5af83d
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-       kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle(
+       kernel_grouped_conv_multiple_d_wmma_cshuffle(
            const ADataType* __restrict__ p_a_grid,
            const BDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    // CheckValidity for kernels without multi D
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const DsGridDesc_M_N& ds_grid_desc_m_n,
                  const EGridDesc_M_N& e_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
    {
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

-       bool valid = true;
-       static_for<0, NumDTensor, 1>{}([&](auto i) {
-           valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
-                             N == ds_grid_desc_m_n[i].GetLength(I1));
-       });
-
-       if(!valid)
-       {
-           return false;
-       }
-
        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
        return true;
    }

+   template <typename Block2CTileMap>
+   __host__ __device__ static constexpr bool
+   CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                 const DsGridDesc_M_N& ds_grid_desc_m_n,
+                 const EGridDesc_M_N& e_grid_desc_m_n,
+                 const Block2CTileMap& block_2_ctile_map)
+   {
+       const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
+       const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
+
+       bool valid = true;
+       static_for<0, NumDTensor, 1>{}([&](auto i) {
+           valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
+                             N == ds_grid_desc_m_n[i].GetLength(I1));
+       });
+
+       if(!valid)
+       {
+           return false;
+       }
+
+       return CheckValidity(
+           a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
+   }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / (K0PerBlock * K1);
...
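The added overload follows a validate-then-delegate pattern: check that every D tensor matches the (M, N) problem size, then fall through to the existing A/B/E check. A stripped-down sketch of the same shape of code, with placeholder types instead of the CK descriptors:

    #include <cstdint>
    #include <vector>

    struct Shape2D { std::int64_t M, N; };

    // Base check: A/B/E agreement (stand-in for the descriptor-based CheckValidity).
    static bool check_validity(const Shape2D& e, std::int64_t M, std::int64_t N)
    {
        return e.M == M && e.N == N;
    }

    // Multi-D overload: every D tensor must match (M, N), then delegate.
    static bool check_validity(const std::vector<Shape2D>& ds, const Shape2D& e,
                               std::int64_t M, std::int64_t N)
    {
        for(const auto& d : ds)
            if(d.M != M || d.N != N)
                return false;
        return check_validity(e, M, N);
    }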
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @ 1b5af83d
...
@@ -22,13 +22,19 @@ namespace ck {
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-         typename Block2CTileMap>
+         typename Block2CTileMap,
+         typename AElementwiseOperation,
+         typename BElementwiseOperation,
+         typename CElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
-                                        const Block2CTileMap& b2c_map)
+                                        const Block2CTileMap& b2c_map,
+                                        const AElementwiseOperation a_element_op,
+                                        const BElementwiseOperation b_element_op,
+                                        const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -37,10 +43,13 @@ __global__ void
    __shared__ uint8_t p_shared[shared_size];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-       karg, static_cast<void*>(p_shared), b2c_map);
+       karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
#else
    ignore = karg;
    ignore = b2c_map;
+   ignore = a_element_op;
+   ignore = b_element_op;
+   ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
              typename Block2CTileMap>
    __device__ static void Run(const Argument& karg,
                               void* __restrict__ p_shared_block,
-                              const Block2CTileMap& block_2_ctile_map)
+                              const Block2CTileMap& block_2_ctile_map,
+                              const AElementwiseOperation a_element_op = AElementwiseOperation{},
+                              const BElementwiseOperation b_element_op = BElementwiseOperation{},
+                              const CElementwiseOperation c_element_op = CElementwiseOperation{})
    {
        const FloatA* p_a_grid = karg.p_a_grid;
        const FloatB* p_b_grid = karg.p_b_grid;
...
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);

-       const AElementwiseOperation a_element_op = AElementwiseOperation{};
-       const BElementwiseOperation b_element_op = BElementwiseOperation{};
-       const CElementwiseOperation c_element_op = CElementwiseOperation{};

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
...
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
-                                                                     ComputeType,
-                                                                     ComputeType,
+                                                                     ComputeType, // ComputeType A
+                                                                     ComputeType, // ComputeType B
                                                                      FloatAcc,
                                                                      decltype(a_k0_m_k1_block_desc),
                                                                      decltype(b_k0_n_k1_block_desc),
...
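Threading the A/B/C element-wise operations through the kernel arguments (with default-constructed defaults on the gridwise Run) lets operators that carry runtime state reach device code instead of being re-created with default state inside Run. A minimal host-side illustration of the same signature change, not the CK kernel itself:

    #include <cstdio>

    struct Scale
    {
        float s = 1.0f;
        float operator()(float x) const { return s * x; }
    };

    // Before: the functor is default-constructed inside, so runtime state is lost.
    float apply_before(float x) { return Scale{}(x); }

    // After: the functor is a parameter with a default, so callers may pass state.
    float apply_after(float x, Scale op = Scale{}) { return op(x); }

    int main()
    {
        std::printf("%f %f\n", apply_before(2.0f), apply_after(2.0f, Scale{3.0f}));
        return 0;
    }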
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
View file @ 1b5af83d
...
@@ -18,9 +18,11 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
+         typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
...
@@ -34,6 +36,7 @@ template <typename XDataType,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
+         index_t SaveMeanInvStdDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationNaiveVariance_mk_to_mk
{
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
        "Invalid thread slice sizes and/or vector sizes configuration, please check!");

+   static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                 "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
...
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

+   using ThreadBufferLengths_M                = Sequence<MThreadSliceSize>;
+   static constexpr auto thread_buffer_desc_m =
+       make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
...
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                                               reduce::Add,
                                                               true>;

+   using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
+                              const GridDesc_M& save_mean_grid_desc_m,
+                              const GridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_block_tile_iteration,
                               ComputeDataType epsilon,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
        // LDS
...
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

+       auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+
+       auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
...
@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            mean_square_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
            var_thread_buf = mean_square_thread_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
+           inv_std_thread_buf = mean_square_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
...
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                       thread_k_cluster_id * YDstVectorSize),
            y_elementwise_op);

+       auto threadwise_mean_store =
+           ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                              SaveMeanInvStdDataType,
+                                              decltype(thread_buffer_desc_m),
+                                              GridDesc_M,
+                                              PassThroughOp,
+                                              ThreadBufferLengths_M,
+                                              Sequence<0>,                 // DimAccessOrder
+                                              0,                           // SrcVectorDim
+                                              SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+               save_mean_grid_desc_m,
+               make_multi_index(block_global_id * M_BlockTileSize +
+                                thread_m_cluster_id * MThreadSliceSize),
+               PassThroughOp{});
+
+       auto threadwise_inv_std_store =
+           ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                              SaveMeanInvStdDataType,
+                                              decltype(thread_buffer_desc_m),
+                                              GridDesc_M,
+                                              PassThroughOp,
+                                              ThreadBufferLengths_M,
+                                              Sequence<0>,                 // DimAccessOrder
+                                              0,                           // SrcVectorDim
+                                              SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+               save_inv_std_grid_desc_m,
+               make_multi_index(block_global_id * M_BlockTileSize +
+                                thread_m_cluster_id * MThreadSliceSize),
+               PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
        // E(x), E[x^2], var(x)
        // FIXME: Should not hack the transform from deviceOP
-       int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
+       ComputeDataType reduce_length = type_convert<ComputeDataType>(
+           x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
...
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            // var(x) = E[x^2] - E[x]^2
            var_thread_buf(I) =
                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+           inv_std_thread_buf(I) =
+               type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

+       // save mean and inverse std for backward (optional)
+       if(thread_k_cluster_id == 0)
+       {
+           if(p_save_mean_global != nullptr)
+           {
+               threadwise_mean_store.Run(thread_buffer_desc_m,
+                                         make_tuple(I0),
+                                         mean_thread_buf,
+                                         save_mean_grid_desc_m,
+                                         save_mean_global_val_buf);
+           }
+           if(p_save_inv_std_global != nullptr)
+           {
+               threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            inv_std_thread_buf,
+                                            save_inv_std_grid_desc_m,
+                                            save_inv_std_global_val_buf);
+           }
+       }

        // normalization
        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                           divisor;
+                           inv_std_thread_buf(iM);

                    // gamma & beta
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
            // var(x) = E[x^2] - E[x]^2
            var_thread_buf(I) =
                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+           inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

+       if(thread_k_cluster_id == 0)
+       {
+           if(p_save_mean_global != nullptr)
+           {
+               threadwise_mean_store.Run(thread_buffer_desc_m,
+                                         make_tuple(I0),
+                                         mean_thread_buf,
+                                         save_mean_grid_desc_m,
+                                         save_mean_global_val_buf);
+           }
+           if(p_save_inv_std_global != nullptr)
+           {
+               threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            inv_std_thread_buf,
+                                            save_inv_std_grid_desc_m,
+                                            save_inv_std_global_val_buf);
+           }
+       }

        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                           divisor;
+                           inv_std_thread_buf(iM);

                    // gamma
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
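Restating the math this kernel implements (matching the in-code comments, with K the reduce length and epsilon the stabilizer):

    E[x]    = (1/K) * sum_k x_k
    E[x^2]  = (1/K) * sum_k x_k^2
    var(x)  = E[x^2] - E[x]^2
    inv_std = 1 / sqrt(var(x) + epsilon)
    y       = gamma * (x - E[x]) * inv_std + beta

The new p_save_mean_global / p_save_inv_std_global outputs simply cache E[x] and inv_std so a backward pass does not have to recompute them.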
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
View file @ 1b5af83d
...
@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
-         typename GridDesc_M_K>
-__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
-                                     const GridDesc_M_K gamma_grid_desc_m_k,
-                                     const GridDesc_M_K beta_grid_desc_m_k,
-                                     const GridDesc_M_K y_grid_desc_m_k,
-                                     index_t num_k_block_tile_iteration,
-                                     ComputeDataType epsilon,
-                                     const XDataType* const __restrict__ p_x_global,
-                                     const GammaDataType* const __restrict__ p_gamma_global,
-                                     const BetaDataType* const __restrict__ p_beta_global,
-                                     YDataType* const __restrict__ p_y_global,
-                                     const YElementwiseOperation y_elementwise_op)
+         typename GridDesc_M_K,
+         typename GridDesc_M>
+__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
+                                     const GridDesc_M_K gamma_grid_desc_m_k,
+                                     const GridDesc_M_K beta_grid_desc_m_k,
+                                     const GridDesc_M_K y_grid_desc_m_k,
+                                     const GridDesc_M save_mean_grid_desc_m,
+                                     const GridDesc_M save_inv_std_grid_desc_m,
+                                     index_t num_k_block_tile_iteration,
+                                     ComputeDataType epsilon,
+                                     const XDataType* const __restrict__ p_x_global,
+                                     const GammaDataType* const __restrict__ p_gamma_global,
+                                     const BetaDataType* const __restrict__ p_beta_global,
+                                     YDataType* const __restrict__ p_y_global,
+                                     SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                                     SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
+                                     const YElementwiseOperation y_elementwise_op)
{
    GridwiseReduction::Run(x_grid_desc_m_k,
                           gamma_grid_desc_m_k,
                           beta_grid_desc_m_k,
                           y_grid_desc_m_k,
                           save_mean_grid_desc_m,
                           save_inv_std_grid_desc_m,
                           num_k_block_tile_iteration,
                           epsilon,
                           p_x_global,
                           p_gamma_global,
                           p_beta_global,
                           p_y_global,
                           p_save_mean_global,
                           p_save_inv_std_global,
                           y_elementwise_op);
};
...
@@ -44,9 +55,11 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
+         typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
...
@@ -60,6 +73,7 @@ template <typename XDataType,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
+         index_t SaveMeanInvStdDstVectorSize,
          bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce)
{
...
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                    GammaDataType,
                                                    BetaDataType,
                                                    YDataType,
+                                                   SaveMeanInvStdDataType,
                                                    ComputeDataType,
                                                    YElementwiseOperation,
                                                    GridDesc_M_K,
+                                                   GridDesc_M,
                                                    BlockSize,
                                                    MThreadClusterSize,
                                                    KThreadClusterSize,
...
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                    BetaSrcVectorSize,
                                                    YDstVectorDim,
                                                    YDstVectorSize,
+                                                   SaveMeanInvStdDstVectorSize,
                                                    false>;
    using GridwiseNormalizationSweepOnceNaive =
        GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
                                                    GammaDataType,
                                                    BetaDataType,
                                                    YDataType,
+                                                   SaveMeanInvStdDataType,
                                                    ComputeDataType,
                                                    YElementwiseOperation,
                                                    GridDesc_M_K,
+                                                   GridDesc_M,
                                                    BlockSize,
                                                    MThreadClusterSize,
                                                    KThreadClusterSize,
...
@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                    BetaSrcVectorSize,
                                                    YDstVectorDim,
                                                    YDstVectorSize,
+                                                   SaveMeanInvStdDstVectorSize,
                                                    true>;
    using GridwiseNormalizationGenericWelford =
        GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
+                                                     SaveMeanInvStdDataType,
                                                      ComputeDataType,
                                                      YElementwiseOperation,
                                                      GridDesc_M_K,
+                                                     GridDesc_M,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
...
@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                      BetaSrcVectorSize,
                                                      YDstVectorDim,
                                                      YDstVectorSize,
+                                                     SaveMeanInvStdDstVectorSize,
                                                      false>;
    using GridwiseNormalizationSweepOnceWelford =
        GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
+                                                     SaveMeanInvStdDataType,
                                                      ComputeDataType,
                                                      YElementwiseOperation,
                                                      GridDesc_M_K,
+                                                     GridDesc_M,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
...
@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                      BetaSrcVectorSize,
                                                      YDstVectorDim,
                                                      YDstVectorSize,
+                                                     SaveMeanInvStdDstVectorSize,
                                                      true>;

    if constexpr(UseWelford)
...
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
+                                         SaveMeanInvStdDataType,
                                          ComputeDataType,
                                          YElementwiseOperation,
-                                         GridDesc_M_K>
+                                         GridDesc_M_K,
+                                         GridDesc_M>
                   : kernel_normalization<GridwiseNormalizationGenericWelford,
                                          XDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
+                                         SaveMeanInvStdDataType,
                                          ComputeDataType,
                                          YElementwiseOperation,
-                                         GridDesc_M_K>;
+                                         GridDesc_M_K,
+                                         GridDesc_M>;
    }
    else
    {
...
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
+                                         SaveMeanInvStdDataType,
                                          ComputeDataType,
                                          YElementwiseOperation,
-                                         GridDesc_M_K>
+                                         GridDesc_M_K,
+                                         GridDesc_M>
                   : kernel_normalization<GridwiseNormalizationGenericNaive,
                                          XDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
+                                         SaveMeanInvStdDataType,
                                          ComputeDataType,
                                          YElementwiseOperation,
-                                         GridDesc_M_K>;
+                                         GridDesc_M_K,
+                                         GridDesc_M>;
    }
}
...
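The selector combines one compile-time switch (UseWelford) with one runtime switch (isSweepOnce) and returns a kernel function pointer. A reduced sketch of the same dispatch shape, with toy functions instead of the CK kernel instantiations:

    #include <cstdio>

    using KernelFn = void (*)();

    void generic_welford()    { std::puts("welford, multi-pass"); }
    void sweep_once_welford() { std::puts("welford, single sweep"); }
    void generic_naive()      { std::puts("naive, multi-pass"); }
    void sweep_once_naive()   { std::puts("naive, single sweep"); }

    template <bool UseWelford>
    KernelFn normalization_kernel_selector(bool is_sweep_once)
    {
        if constexpr(UseWelford)
            return is_sweep_once ? sweep_once_welford : generic_welford;
        else
            return is_sweep_once ? sweep_once_naive : generic_naive;
    }

    int main()
    {
        normalization_kernel_selector<true>(false)();  // welford, multi-pass
        normalization_kernel_selector<false>(true)();  // naive, single sweep
        return 0;
    }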
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
View file @ 1b5af83d
...
@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
+         typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename MeanVarGridDesc_M_KBlock,
          typename CountGridDesc_M_KBlock,
          typename XYGammaBetaGridDesc_M_K,
+         typename SaveMeanInvStdGridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
...
@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
-         index_t YDstVectorSize>
+         index_t YDstVectorSize,
+         index_t SaveMeanInvStdDstVectorSize>
struct GridwiseNormalizationSplitK2nd
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
            (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
        "Invalid thread slice sizes and/or vector sizes configuration, please check!");

+   static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                 "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
...
@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

+   using ThreadBufferLengths_M                = Sequence<MThreadSliceSize>;
+   static constexpr auto thread_buffer_desc_m =
+       make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using ThreadBufferLengths_M_1              = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
...
@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
                               const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
+                              const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
+                              const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
...
@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
        // Thread/Block id
...
@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

+       auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+
+       auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());

        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_mean_thread_buf;
...
@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
            var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;
+       auto& inv_std_thread_buf = var_thread_buf;

        auto x_thread_buf = generate_tuple(
            [&](auto) {
...
@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
                       thread_k_cluster_id * YDstVectorSize),
            y_elementwise_op);

+       auto threadwise_mean_store =
+           ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                              SaveMeanInvStdDataType,
+                                              decltype(thread_buffer_desc_m),
+                                              SaveMeanInvStdGridDesc_M,
+                                              PassThroughOp,
+                                              ThreadBufferLengths_M,
+                                              Sequence<0>,                 // DimAccessOrder
+                                              0,                           // SrcVectorDim
+                                              SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+               save_mean_grid_desc_m,
+               make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                thread_m_cluster_id * MThreadSliceSize),
+               PassThroughOp{});
+
+       auto threadwise_inv_std_store =
+           ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                              SaveMeanInvStdDataType,
+                                              decltype(thread_buffer_desc_m),
+                                              SaveMeanInvStdGridDesc_M,
+                                              PassThroughOp,
+                                              ThreadBufferLengths_M,
+                                              Sequence<0>,                 // DimAccessOrder
+                                              0,                           // SrcVectorDim
+                                              SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+               save_inv_std_grid_desc_m,
+               make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                thread_m_cluster_id * MThreadSliceSize),
+               PassThroughOp{});

        // step1: Merge mean and variance
        constexpr auto mean_var_count_thread_copy_step_I0_k =
            make_multi_index(I0, KThreadClusterSize);
...
@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
            BlockwiseWelford::Run(
                mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
+           inv_std_thread_buf(I) =
+               type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

-       // step2: normalization
+       // step2: save mean and inverse std for backward (optional)
+       if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
+       {
+           if(p_save_mean_global != nullptr)
+           {
+               threadwise_mean_store.Run(thread_buffer_desc_m,
+                                         make_tuple(I0),
+                                         mean_thread_buf,
+                                         save_mean_grid_desc_m,
+                                         save_mean_global_val_buf);
+           }
+           if(p_save_inv_std_global != nullptr)
+           {
+               threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            inv_std_thread_buf,
+                                            save_inv_std_grid_desc_m,
+                                            save_inv_std_global_val_buf);
+           }
+       }
+
+       // step3: normalization
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);

        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
...
@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                           divisor;
+                           inv_std_thread_buf(iM);

                    // gamma
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
View file @
1b5af83d
...
...
@@ -16,9 +16,11 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
          typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
...
@@ -32,6 +34,7 @@ template <typename XDataType,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          index_t SaveMeanInvStdDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationWelfordVariance_mk_to_mk
{
...
@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
        "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
                  "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
...
@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
...
@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                            ThreadClusterLengths_M_K,
                                            ThreadClusterArrangeOrder>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...
@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               const GridDesc_M& save_mean_grid_desc_m,
                               const GridDesc_M& save_inv_std_grid_desc_m,
                               index_t num_k_block_tile_iteration,
                               ComputeDataType epsilon,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                               SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
    {
        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
...
@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;
        auto& inv_std_thread_buf = var_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
...
@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                thread_k_cluster_id * YDstVectorSize),
            y_elementwise_op);

        auto threadwise_mean_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               SaveMeanInvStdDataType,
                                               decltype(thread_buffer_desc_m),
                                               GridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,                 // DimAccessOrder
                                               0,                           // SrcVectorDim
                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                save_mean_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_inv_std_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               SaveMeanInvStdDataType,
                                               decltype(thread_buffer_desc_m),
                                               GridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,                 // DimAccessOrder
                                               0,                           // SrcVectorDim
                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                save_inv_std_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...
@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());

        auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
...
@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
            inv_std_thread_buf(I) =
                type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

        // save mean and inverse std for backward (optional)
        if(thread_k_cluster_id == 0)
        {
            if(p_save_mean_global != nullptr)
            {
                threadwise_mean_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          mean_thread_buf,
                                          save_mean_grid_desc_m,
                                          save_mean_global_val_buf);
            }
            if(p_save_inv_std_global != nullptr)
            {
                threadwise_inv_std_store.Run(thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_std_thread_buf,
                                             save_inv_std_grid_desc_m,
                                             save_inv_std_global_val_buf);
            }
        }

        // normalization
        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
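For readers following the hunk above: ThreadwiseWelford/BlockwiseWelford accumulate the per-row mean and variance with Welford's online algorithm, and the inverse standard deviation is then 1 / sqrt(var + epsilon). A minimal scalar sketch of that update in plain C++ (helper names here are invented for illustration and are not the kernel's API):

#include <cmath>

// Minimal scalar sketch of Welford's online mean/variance update, assuming a
// single accumulator; the kernel spreads these updates across threads and then
// merges the partial results blockwise.
struct WelfordState
{
    float mean  = 0.f;
    float m2    = 0.f; // sum of squared deviations from the running mean
    int   count = 0;

    void update(float x)
    {
        count += 1;
        float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    float variance() const { return count > 0 ? m2 / count : 0.f; }
};

// inv_std as used by the kernel: 1 / sqrt(var + epsilon)
inline float inverse_std(const WelfordState& s, float epsilon)
{
    return 1.f / std::sqrt(s.variance() + epsilon);
}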
@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * divisor;
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                        inv_std_thread_buf(iM);

                    // gamma & beta
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
            inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });

        if(thread_k_cluster_id == 0)
        {
            if(p_save_mean_global != nullptr)
            {
                threadwise_mean_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          mean_thread_buf,
                                          save_mean_grid_desc_m,
                                          save_mean_global_val_buf);
            }
            if(p_save_inv_std_global != nullptr)
            {
                threadwise_inv_std_store.Run(thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_std_thread_buf,
                                             save_inv_std_grid_desc_m,
                                             save_inv_std_global_val_buf);
            }
        }

        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...
@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * divisor;
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                        inv_std_thread_buf(iM);

                    // gamma
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
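The inner loops above reduce, per element, to the usual layer-norm affine transform once the row mean and inverse standard deviation are known. As a scalar sketch (hypothetical helper, not part of the kernel):

// y = gamma * (x - mean) * inv_std + beta, applied element by element per row.
inline float normalize_one(float x, float mean, float inv_std, float gamma, float beta)
{
    return gamma * (x - mean) * inv_std + beta;
}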
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp View file @ 1b5af83d
...
@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
    }
};

#if defined CK_ENABLE_FP8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
...
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
        intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
#endif

#if defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{
...
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
        intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
#endif

#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{
...
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
        intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
#endif

#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{
...
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
        intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
#endif

template <typename base_type,
          index_t MPerXdlops,
...
@@ -792,7 +784,6 @@ struct MfmaSelector
    }
#endif

#if defined CK_ENABLE_FP8
    template <>
    static constexpr auto GetMfma<f8_t, 32, 32>()
    {
...
@@ -804,9 +795,7 @@ struct MfmaSelector
    {
        return MfmaInstr::mfma_f32_16x16x32f8f8;
    }
#endif

#if defined CK_ENABLE_BF8
    template <>
    static constexpr auto GetMfma<bf8_t, 32, 32>()
    {
...
@@ -818,9 +807,7 @@ struct MfmaSelector
    {
        return MfmaInstr::mfma_f32_16x16x32bf8bf8;
    }
#endif

#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
    template <>
    static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
    {
...
@@ -832,9 +819,7 @@ struct MfmaSelector
    {
        return MfmaInstr::mfma_f32_16x16x32f8bf8;
    }
#endif

#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
    template <>
    static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
    {
...
@@ -846,7 +831,6 @@ struct MfmaSelector
    {
        return MfmaInstr::mfma_f32_16x16x32bf8f8;
    }
#endif

    static constexpr auto selected_mfma =
        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
...
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                          is_same<base_type, half_t>::value ||
                          is_same<base_type, bhalf_t>::value ||
                          is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
                          || is_same<base_type, f8_t>::value
#endif
#if defined CK_ENABLE_BF8
                          || is_same<base_type, bf8_t>::value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
                          || (is_same<base_type, f8_t>::value &&
                              is_same<additional_type, bf8_t>::value) ||
                          (is_same<base_type, bf8_t>::value &&
                           is_same<additional_type, f8_t>::value)
#endif
                          ,
                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
                          is_same<base_type, bf8_t>::value ||
                          (is_same<base_type, f8_t>::value &&
                           is_same<additional_type, bf8_t>::value) ||
                          (is_same<base_type, bf8_t>::value &&
                           is_same<additional_type, f8_t>::value),
                      "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");

        static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
...
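The MfmaSelector hunks above map a data type and tile size to an MFMA instruction enum at compile time through template specialization. A simplified, self-contained sketch of the same pattern (all names below are invented for illustration and are not the library's identifiers):

// Each (type, MPerWave, NPerWave) combination selects an instruction tag;
// an unsupported combination fails to compile because the primary template
// is left undefined.
enum class Instr
{
    f32_32x32x16f8f8,
    f32_16x16x32f8f8
};

struct f8_tag {}; // stand-in for an fp8 element type

template <typename T, int MPerWave, int NPerWave>
struct InstrSelector; // primary template intentionally undefined

template <>
struct InstrSelector<f8_tag, 32, 32>
{
    static constexpr Instr value = Instr::f32_32x32x16f8f8;
};

template <>
struct InstrSelector<f8_tag, 16, 16>
{
    static constexpr Instr value = Instr::f32_16x16x32f8f8;
};

static_assert(InstrSelector<f8_tag, 32, 32>::value == Instr::f32_32x32x16f8f8,
              "selection is resolved entirely at compile time");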
include/ck/utility/amd_xdlops.hpp View file @ 1b5af83d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

#include "data_type.hpp"
#pragma once

namespace ck {
...
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
    }
};

#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
...
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
    }
};
#endif

#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;
...
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif
    }
};
#endif

#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;
...
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif
    }
};
#endif

#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;
...
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif
    }
};
#endif

} // namespace ck
#endif
include/ck/utility/data_type.hpp View file @ 1b5af83d
...
@@ -9,15 +9,9 @@ namespace ck {
using bhalf_t = ushort;
using half_t  = _Float16;

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
#if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
using int4_t = _BitInt(4);
using f8_t   = _BitInt(8);
using bf8_t  = unsigned _BitInt(8);

// vector_type
template <typename T, index_t N>
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
#endif

#if defined CK_ENABLE_FP8
template <>
struct scalar_type<f8_t>
{
    using type                           = f8_t;
    static constexpr index_t vector_size = 1;
};
#endif

#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
    using type                           = bf8_t;
    static constexpr index_t vector_size = 1;
};
#endif

template <typename T>
struct vector_type<T, 1>
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;

// f8
#if defined CK_ENABLE_FP8
using f8x2_t  = typename vector_type<f8_t, 2>::type;
using f8x4_t  = typename vector_type<f8_t, 4>::type;
using f8x8_t  = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif

// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif

template <typename T>
struct NumericLimits
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

#if defined CK_ENABLE_FP8
template <>
struct NumericLimits<f8_t>
{
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
    __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif

#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
    __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif

template <typename T>
struct NumericUtils
...
@@ -1093,6 +1075,7 @@ struct NumericUtils<float>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;
    static constexpr uint32_t nan_mask  = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
...
@@ -1109,6 +1092,7 @@ struct NumericUtils<half_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;
    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
...
@@ -1120,22 +1104,21 @@ struct NumericUtils<half_t>
    using bitwise_type = uint16_t;
};

#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;
    static constexpr int bias = 8; // negative zero nan mode
    // static constexpr int bias = 7; // ieee mode
};
#endif

#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
    static constexpr int bias = 16; // negative zero nan mode
    // static constexpr int bias = 15; // ieee mode
};
#endif
} // namespace ck
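The bias constants added to NumericUtils above complete the exp/mant/bias description of each type's bit layout, which the f8 conversion code in the next file relies on. A small host-side sketch of how such constants decompose a float (standard IEEE-754 single-precision layout; helper names chosen here, not the library's API):

#include <cstdint>
#include <cstring>

// Split a float into sign / biased exponent / mantissa using the same
// constants NumericUtils<float> carries: exp = 8, mant = 23, bias = 127.
struct FloatFields
{
    uint32_t sign;     // 1 bit
    uint32_t exponent; // 8 bits, biased by 127
    uint32_t mantissa; // 23 bits, without the implicit leading 1
};

inline FloatFields decompose(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits)); // bit-exact reinterpretation

    FloatFields f;
    f.mantissa = bits & 0x7FFFFFu;     // mant_mask
    f.exponent = (bits >> 23) & 0xFFu; // shift past the mantissa bits
    f.sign     = bits >> 31;           // sign occupies the top bit
    return f;
}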
include/ck/utility/f8_utils.hpp View file @ 1b5af83d
...
@@ -5,9 +5,6 @@
#include "ck/utility/data_type.hpp"

// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8

namespace ck {

// fp8 rounding modes
...
@@ -19,6 +16,9 @@ enum class f8_rounding_mode
    stochastic
};

__host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
__device__ inline int clz(uint32_t x) { return __clz(x); }

} // namespace ck

namespace ck::utils {
...
@@ -36,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
    constexpr int in_exp  = NumericUtils<X>::exp;
    constexpr int in_mant = NumericUtils<X>::mant;
    int exponent;
    int exponent, bias;
    uint32_t head, mantissa, sign;
    // nan code is same for float and half
    constexpr Y nan_code = 0x80;
...
@@ -51,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
    mantissa = x_bitwise & NumericUtils<X>::mant_mask;
    exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
    sign     = head >> (in_exp + in_mant);
    bias     = NumericUtils<X>::bias;

    uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
    uint32_t drop_mask  = (1 << (in_mant - out_mant)) - 1;
    constexpr int max_exp        = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
    constexpr int exp_low_cutoff =
        (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    if constexpr(negative_zero_nan)
    {
...
@@ -69,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }

    // if input is half and output is bf8
    if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
       exponent == 0)
    {
        exponent += 1;
        while(mantissa < (1 << in_mant))
        {
            mantissa <<= 1;
            exponent -= 1;
        }
        mantissa &= ~(1 << in_mant);
    }
    // check if x is 0.0
    if(x_bitwise == 0)
        return 0;
    exponent -= exp_low_cutoff - 1;
    if(exponent <= 0)
        drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
    mantissa += 1 << in_mant;
    // apply random number if needed
    mantissa += (stoch ? rng : mantissa) & drop_mask;
    if(mantissa >= (2 << in_mant))
    {
        mantissa >>= 1;
        exponent++;

    // First need to check if it is normal or denorm as there is a difference of implict 1
    // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
    // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
    // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
    // exponent and mantissa again3

    // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
    const int out_bias                  = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
    const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal

    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // out_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, out_exponent, exponent_diff;

    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = out_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= out_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = out_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
                               // act_exponent could be larger. Just that it does not need shift mantissa
        }
        mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
    }
    mantissa >>= (in_mant - out_mant);
    // check negative exponent
    if(exponent <= 0)

    bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
                    (1 << (in_mant - out_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << in_mant);
    // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
    out_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    bool odd =
        mantissa & (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
    mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(out_exponent == 0)
    {
        if(x_bitwise == 0)
            return 0;
        else if((1 << in_mant) & mantissa)
        {
            // subnormal range; represented by a subnormal float8 (exponent 0)
            // and involves loss of accuracy
            mantissa >>= 1 - exponent;
            exponent = 0;
            out_exponent = 1; // denormal overflow to become normal, promote exponent
            // No need to make 1 implicit now as it will be addressed later
        }
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    else
    {
        if((1 << (in_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
            // No need to make 1 implicit now as it will be addressed later
        }
    }
    mantissa >>= (in_mant - out_mant);

    if(out_exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1 << out_mant) - 1;
            exponent = max_exp;
            mantissa     = (1 << out_mant) - 1;
            out_exponent = max_exp;
        }
        else
        {
...
@@ -127,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
    }

    // check if x is 0.0 or -0.0
    if(exponent == 0 && mantissa == 0)
    if(out_exponent == 0 && mantissa == 0)
        return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
    mantissa &= (1 << out_mant) - 1;
    return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
    return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
}

template <typename X, typename Y, bool negative_zero_nan>
...
@@ -196,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        exponent++;
        while(mantissa < (1 << in_mant))
        {
            mantissa <<= 1;
            exponent--;
        }
        int sh = 1 + clz(mantissa) - (32 - in_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
...
@@ -244,5 +291,3 @@ __host__ __device__ Y cast_from_f8(X x)
}

} // namespace ck::utils

#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
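The midpoint/odd logic in run_cast_to_f8 above implements round-to-nearest-even on the dropped mantissa bits. A standalone sketch of that rounding rule on an integer mantissa (illustrative only; it omits the kernel's exponent handling and stochastic-rounding path, and the helper name is invented here):

#include <cstdint>

// Round an integer mantissa from in_bits of fraction down to out_bits using
// round-to-nearest-even: exact midpoints round toward the value whose kept
// least-significant bit is even. Assumes in_bits > out_bits.
inline uint32_t round_to_nearest_even(uint32_t mantissa, int in_bits, int out_bits)
{
    const int      shift     = in_bits - out_bits;
    const uint32_t drop_mask = (1u << shift) - 1u;
    const uint32_t dropped   = mantissa & drop_mask;
    const uint32_t half      = 1u << (shift - 1);

    uint32_t kept = mantissa >> shift;

    if(dropped > half)
        kept += 1; // above the midpoint: round up
    else if(dropped == half && (kept & 1u))
        kept += 1; // exact midpoint: round up only if the kept LSB is odd

    return kept;
}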