Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
45c6c530
"backends/v2/src/client/grpc_client.rs" did not exist on "f16f2f5ae15a1b728559e1ac1a8b840f15a54efd"
Unverified
Commit
45c6c530
authored
Nov 10, 2023
by
arai713
Committed by
GitHub
Nov 10, 2023
Browse files
Merge branch 'develop' into hip_tensor_permute
parents
4026fced
49e52bb3
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1922 additions
and
349 deletions
+1922
-349
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+14
-15
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+8
-6
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
...gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
+8
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
.../impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
+332
-150
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+126
-3
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
...operation/gpu/device/impl/device_image_to_column_impl.hpp
+14
-15
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
.../device/impl/device_normalization_bwd_gamma_beta_impl.hpp
+464
-0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+5
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+15
-9
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
...d/normalization/gridwise_normalization_bwd_gamma_beta.hpp
+343
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+195
-129
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
...eference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
+207
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
...eference_tensor_operation/cpu/reference_layernorm_bwd.hpp
+177
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
View file @
45c6c530
...
...
@@ -263,19 +263,18 @@ struct DeviceColumnToImageImpl
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
InputGridDesc
>
(
InputGridDesc
{}))
>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Add
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Add
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>>
;
struct
Argument
:
public
BaseArgument
{
...
...
@@ -453,7 +452,7 @@ struct DeviceColumnToImageImpl
std
::
vector
<
const
InputDataType
*>
p_in_container_
;
std
::
vector
<
OutputDataType
*>
p_out_container_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
};
struct
Invoker
:
public
BaseInvoker
...
...
@@ -471,7 +470,7 @@ struct DeviceColumnToImageImpl
OutputGridDesc
,
OutputDataType
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
GridwiseTensorRearrangeKernel
>
;
// Execute each set of independent filters
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
45c6c530
...
...
@@ -385,9 +385,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
...
@@ -397,7 +399,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
Make
Default
Block2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -429,7 +431,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
Make
Default
Block2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
...
...
@@ -481,10 +483,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
GridwiseGemm
::
Make
Default
AsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
GridwiseGemm
::
Make
Default
BsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
45c6c530
...
...
@@ -305,9 +305,11 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
...
@@ -317,7 +319,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
Make
Default
Block2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -349,7 +351,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
Make
Default
Block2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
...
...
@@ -407,10 +409,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
GridwiseGemm
::
Make
Default
AsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
GridwiseGemm
::
Make
Default
BsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
45c6c530
...
...
@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOp
a_element_op_
;
...
...
@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
45c6c530
...
...
@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std
::
vector
<
Block2ETileMap
>
block_2_etile_map_container_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOp
a_element_op_
;
...
...
@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
45c6c530
...
...
@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
// element-wise op
OutElementwiseOperation
a_element_op_
;
...
...
@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
has_main_loop
,
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
View file @
45c6c530
...
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
OutElementwiseOperation
a_element_op_
;
InElementwiseOperation
b_element_op_
;
...
...
@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
has_main_loop
>
;
using
EmptyTuple
=
Tuple
<>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
45c6c530
...
...
@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
45c6c530
...
...
@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DefaultBlock2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DeviceOp
::
DsGridDesc_M0_M10_M11_N0_N10_N11
,
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
,
DefaultBlock2CTileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
,
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
45c6c530
...
...
@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_etile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
45c6c530
...
...
@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -56,7 +57,8 @@ namespace {
*
*/
template
<
typename
GridwiseGemm
,
typename
ABDataType
,
typename
AsPointer
,
// tuples if multi AB, pointers if no
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
...
...
@@ -68,14 +70,16 @@ template <typename GridwiseGemm,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
isMultiA
,
bool
isMultiB
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
AsPointer
p_a
s
_grid
,
BsPointer
p_b
s
_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
...
...
@@ -98,14 +102,9 @@ __global__ void
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
const
auto
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
const
auto
&
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -117,22 +116,63 @@ __global__ void
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_batch_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
if
constexpr
(
isMultiA
||
isMultiB
)
{
AsPointer
p_as_grid_grp
;
BsPointer
p_bs_grid_grp
;
const
auto
&
as_batch_offset
=
compute_ptr_offset_of_batch
.
GetAsPtrOffset
(
g_idx
);
static
constexpr
index_t
NumATensor
=
AGridDesc_AK0_M_AK1
::
Size
();
static_for
<
0
,
NumATensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_as_grid_grp
(
i
)
=
p_as_grid
[
i
]
+
as_batch_offset
[
i
];
});
const
auto
&
bs_batch_offset
=
compute_ptr_offset_of_batch
.
GetBsPtrOffset
(
g_idx
);
static
constexpr
index_t
NumBTensor
=
BGridDesc_BK0_N_BK1
::
Size
();
static_for
<
0
,
NumBTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_bs_grid_grp
(
i
)
=
p_bs_grid
[
i
]
+
bs_batch_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid_grp
,
p_bs_grid_grp
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
}
else
{
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
+
a_batch_offset
,
p_bs_grid
+
b_batch_offset
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
}
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_a
s
_grid
;
ignore
=
p_b
s
_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
...
...
@@ -150,6 +190,9 @@ __global__ void
}
// namespace
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
//
// @brief Device Convolution operation.
//
...
...
@@ -211,8 +254,13 @@ template <index_t NDimSpatial,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
typename
ComputeDataType
=
ADataType
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
typename
ComputeDataType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
()),
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()
>
struct
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
:
public
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
ALayout
,
...
...
@@ -230,6 +278,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
;
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -325,51 +378,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
using
GemmADataType
=
std
::
conditional_t
<!
isMultiA
&&
isMultiB
,
Tuple
<
ADataType
>
,
ADataType
>
;
using
GemmBDataType
=
std
::
conditional_t
<!
isMultiB
&&
isMultiA
,
Tuple
<
BDataType
>
,
BDataType
>
;
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
// Use appropriate gridwise gemm
using
GridwiseGemm
=
std
::
conditional_t
<
isMultiA
||
isMultiB
,
GridwiseGemmMultipleABD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>
,
GridwiseGemmMultipleD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>>
;
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
using
APointers
=
std
::
conditional_t
<
isMultiA
,
std
::
array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
std
::
array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
// in initializer list what is required for single const pointer).
using
AGridPointer
=
remove_cvref_t
<
decltype
(
GetAGridPointer
<
isMultiA
||
isMultiB
,
GridwiseGemm
,
ADataType
>
())
>
;
using
BGridPointer
=
remove_cvref_t
<
decltype
(
GetBGridPointer
<
isMultiA
||
isMultiB
,
GridwiseGemm
,
BDataType
>
())
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
...
...
@@ -392,8 +437,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a
,
const
void
*
p_b
,
Argument
(
APointers
p_a
s
,
BPointers
p_b
s
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
@@ -413,8 +458,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)
},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)
},
:
p_a
s
_grid_
{},
p_b
s
_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
...
...
@@ -458,9 +503,58 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
input_right_pads_
{
input_right_pads
}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
if
constexpr
(
isMultiA
||
isMultiB
)
{
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
0
];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmADataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if
constexpr
(
isMultiA
)
{
// p_as is tuple
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
[
i
.
value
]);
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
}
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_
.
BatchStrideB_
(
i
)
=
b_g_k_c_xs_strides
[
0
];
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if
constexpr
(
isMultiB
)
{
// p_bs is tuple
p_bs_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_bs
[
i
.
value
]);
}
else
{
// if MultiA and not MultiB then p_bs is single pointer
p_bs_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_bs
);
}
});
}
else
{
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
// p_as and p_bs are pointers
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_bs_grid_
(
I0
)
=
static_cast
<
const
BDataType
*>
(
p_bs
);
}
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
...
@@ -477,21 +571,47 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
]);
});
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
// populate desc for Ds/E
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
if
constexpr
(
isMultiA
||
isMultiB
)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
const
auto
as_grid_desc_ak0_m_ak1
=
generate_tuple
([
&
](
auto
)
{
return
a_grid_desc_m_k_
;
},
Number
<
NumATensor
>
{});
const
auto
bs_grid_desc_bk0_n_bk1
=
generate_tuple
([
&
](
auto
)
{
return
b_grid_desc_n_k_
;
},
Number
<
NumBTensor
>
{});
if
(
GridwiseGemm
::
CheckValidity
(
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
}
else
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
}
}
...
...
@@ -505,9 +625,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
}
// private:
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
// pointers
(tuple if multi AB, pointer if no)
AGridPointer
p_a
s
_grid_
;
BGridPointer
p_b
s
_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
...
...
@@ -529,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
Block2ETileMap
block_2_etile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -563,16 +684,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
arg
.
Print
();
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
arg
.
num_group_
;
...
...
@@ -582,41 +693,96 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
if
constexpr
(
isMultiA
||
isMultiB
)
{
// Generate tuples with grid descriptors for each A and B
const
auto
as_grid_desc_ak0_m_ak1
=
generate_tuple
(
[
&
](
auto
)
{
return
arg
.
a_grid_desc_ak0_m_ak1_
;
},
Number
<
NumATensor
>
{});
const
auto
bs_grid_desc_bk0_n_bk1
=
generate_tuple
(
[
&
](
auto
)
{
return
arg
.
b_grid_desc_bk0_n_bk1_
;
},
Number
<
NumBTensor
>
{});
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle
<
GridwiseGemm
,
AGridPointer
,
BGridPointer
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
decltype
(
as_grid_desc_ak0_m_ak1
),
decltype
(
bs_grid_desc_bk0_n_bk1
),
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
,
has_main_loop
,
isMultiA
,
isMultiB
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_as_grid_
,
arg
.
p_bs_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
}
else
{
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle
<
GridwiseGemm
,
const
ADataType
*
,
const
BDataType
*
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
,
has_main_loop
,
isMultiA
,
isMultiB
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_as_grid_
.
At
(
I0
),
// Pass just A descriptor instead of tuple
arg
.
p_bs_grid_
.
At
(
I0
),
// Pass just B descriptor instead of tuple
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
}
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
...
...
@@ -791,11 +957,27 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
}
// check Gridwise GEMM
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
if
constexpr
(
isMultiA
||
isMultiB
)
{
// Genarate tuples with the same descriptors
const
auto
as_grid_desc_ak0_m_ak1
=
generate_tuple
([
&
](
auto
)
{
return
arg
.
a_grid_desc_m_k_
;
},
Number
<
NumATensor
>
{});
const
auto
bs_grid_desc_bk0_n_bk1
=
generate_tuple
([
&
](
auto
)
{
return
arg
.
b_grid_desc_n_k_
;
},
Number
<
NumBTensor
>
{});
return
GridwiseGemm
::
CheckValidity
(
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
else
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -804,8 +986,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
APointers
p_a
s
,
BPointers
p_b
s
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
@@ -824,8 +1006,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
return
Argument
{
p_a
s
,
p_b
s
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
...
...
@@ -848,8 +1030,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
APointers
p_a
,
BPointers
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
45c6c530
...
...
@@ -9,8 +9,77 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Num
DTensor
>
template
<
index_t
Num
ATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
struct
ComputePtrOffsetOfStridedBatch
{
};
template
<
index_t
NumATensor
,
index_t
NumBTensor
,
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
,
ck
::
enable_if_t
<
(
NumATensor
>
1
||
NumBTensor
>
1
)
>>
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
Array
<
ck
::
index_t
,
NumATensor
>&
BatchStrideAs
,
Array
<
ck
::
index_t
,
NumBTensor
>&
BatchStrideBs
,
Array
<
ck
::
index_t
,
NumDTensor
>&
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideAs
),
BatchStrideB_
(
BatchStrideBs
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
auto
GetAsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumATensor
>
as_offset
;
static_for
<
0
,
NumATensor
,
1
>
{}(
[
&
](
auto
i
)
{
as_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
[
i
]);
});
return
as_offset
;
}
__host__
__device__
constexpr
auto
GetBsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumBTensor
>
bs_offset
;
static_for
<
0
,
NumBTensor
,
1
>
{}(
[
&
](
auto
i
)
{
bs_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
[
i
]);
});
return
bs_offset
;
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
// alias for kernels without multiple D
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
Array
<
ck
::
index_t
,
NumATensor
>
BatchStrideA_
;
Array
<
ck
::
index_t
,
NumBTensor
>
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
template
<
index_t
NumATensor
,
index_t
NumBTensor
,
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
,
ck
::
enable_if_t
<
(
NumATensor
==
1
&&
NumBTensor
==
1
)
>>
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
...
...
@@ -54,13 +123,67 @@ struct ComputePtrOffsetOfStridedBatch
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
ck
::
index_t
BatchStrideA_
;
ck
::
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
template
<
bool
isTuple
,
typename
Tensors
>
constexpr
static
auto
GetNumABTensors
()
{
if
constexpr
(
isTuple
)
{
return
Number
<
Tensors
::
Size
()
>
{};
}
else
{
return
Number
<
1
>
{};
}
}
template
<
bool
isTuple
,
typename
GridwiseGemm
,
typename
DataType
>
constexpr
static
auto
GetAGridPointer
()
{
if
constexpr
(
isTuple
)
{
return
typename
GridwiseGemm
::
AsGridPointer
{};
}
else
{
return
Tuple
<
const
DataType
*>
{};
}
}
template
<
bool
isTuple
,
typename
GridwiseGemm
,
typename
DataType
>
constexpr
static
auto
GetBGridPointer
()
{
if
constexpr
(
isTuple
)
{
return
typename
GridwiseGemm
::
BsGridPointer
{};
}
else
{
return
Tuple
<
const
DataType
*>
{};
}
}
template
<
bool
isTuple
,
typename
Id
,
typename
Type
>
constexpr
static
auto
UnpackDataType
()
{
if
constexpr
(
isTuple
)
{
// unpack if tuple
return
tuple_element_t
<
Id
{},
Type
>
{};
}
else
{
// if no, return Type
return
Type
{};
}
}
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
View file @
45c6c530
...
...
@@ -142,19 +142,18 @@ struct DeviceImageToColumnImpl
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
OutputGridDesc
>
(
OutputGridDesc
{}))
>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>>
;
struct
Argument
:
public
BaseArgument
{
...
...
@@ -224,7 +223,7 @@ struct DeviceImageToColumnImpl
InputGridDesc
in_grid_desc_m_k_
;
OutputGridDesc
out_grid_desc_m_k_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
};
struct
Invoker
:
public
BaseInvoker
...
...
@@ -246,7 +245,7 @@ struct DeviceImageToColumnImpl
OutputGridDesc
,
OutputDataType
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
GridwiseTensorRearrangeKernel
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
0 → 100644
View file @
45c6c530
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is invarient dimension, K is reduced dimension
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
GridwiseReduction
,
typename
DYDataType
,
typename
XDataType
,
typename
MeanInvStdDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
typename
GridDesc_M_K
,
typename
GridDesc_M
>
__global__
void
kernel_normalization_bwd_gamma_beta
(
const
GridDesc_M_K
dy_grid_desc_m_k
,
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_M_K
mean_grid_desc_m_k
,
const
GridDesc_M_K
inv_std_grid_desc_m_k
,
const
GridDesc_M
dgamma_grid_desc_m
,
const
GridDesc_M
dbeta_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
const
DYDataType
*
const
__restrict__
p_dy_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_mean_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_inv_std_global
,
DGammaDataType
*
const
__restrict__
p_dgamma_global
,
DBetaDataType
*
const
__restrict__
p_dbeta_global
)
{
GridwiseReduction
::
Run
(
dy_grid_desc_m_k
,
x_grid_desc_m_k
,
mean_grid_desc_m_k
,
inv_std_grid_desc_m_k
,
dgamma_grid_desc_m
,
dbeta_grid_desc_m
,
num_k_block_tile_iteration
,
p_dy_global
,
p_x_global
,
p_mean_global
,
p_inv_std_global
,
p_dgamma_global
,
p_dbeta_global
);
};
template
<
typename
DYDataType
,
typename
XDataType
,
typename
MeanInvStdDataType
,
typename
ComputeDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
bool
IsDYFastestDimReduced
,
index_t
DYSrcVectorSize
,
bool
IsXFastestDimReduced
,
index_t
XSrcVectorSize
,
bool
IsMeanInvStdFastestDimReduced
,
index_t
MeanInvStdSrcVectorSize
,
index_t
DGammaDstVectorSize
,
index_t
DBetaDstVectorSize
>
struct
DeviceNormalizationBwdGammaBetaImpl
:
public
DeviceNormalizationBwdGammaBeta
<
DYDataType
,
XDataType
,
MeanInvStdDataType
,
DGammaDataType
,
DBetaDataType
,
Rank
,
NumReduceDim
>
{
static
constexpr
index_t
DYSrcVectorDim
=
IsDYFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
XSrcVectorDim
=
IsXFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
MeanInvStdSrcVectorDim
=
IsMeanInvStdFastestDimReduced
?
1
:
0
;
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
%
DYSrcVectorSize
==
0
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
%
DYSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
static_assert
(((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static_assert
(
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static_assert
(
!
reduceAllDim
);
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
numBlockTileIteration
)
{
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
Rank
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
Rank
>
{});
const
auto
inDesc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
in_grid_desc_m_k
=
[
&
]()
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
ReduceDims
{});
const
auto
invariantDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
InvariantDims
{});
return
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
invariantLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
auto
inPad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
inPad_K
=
K_BlockTileSize
*
numBlockTileIteration
-
reduceLength
;
auto
in_grid_desc_m_k_padded
=
transform_tensor_descriptor
(
in_grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
inPad_M
),
make_right_pad_transform
(
reduceLength
,
inPad_K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_grid_desc_m_k_padded
;
}
static
auto
MakeDst1dDescriptor
(
const
std
::
vector
<
index_t
>&
outLengths
,
const
std
::
vector
<
index_t
>&
outStrides
)
{
const
auto
tupleDstLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
outLengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
const
auto
tupleDstStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
outStrides
[
I
];
},
Number
<
NumInvariantDim
>
{});
auto
outDesc
=
make_naive_tensor_descriptor
(
tupleDstLengths
,
tupleDstStrides
);
auto
out_grid_desc_m
=
transform_tensor_descriptor
(
outDesc
,
make_tuple
(
make_merge_transform
(
tupleDstLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
out_grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
outPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
out_grid_desc_m_padded
=
transform_tensor_descriptor
(
out_grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
outPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
out_grid_desc_m_padded
);
};
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
));
using
GridDesc_M
=
decltype
(
MakeDst1dDescriptor
({
1
},
{
1
}));
using
GridwiseNormalizationBwdGammaBeta
=
GridwiseNormalizationBwdGammaBeta_mk_to_k
<
DYDataType
,
XDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DGammaDataType
,
DBetaDataType
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
DYSrcVectorDim
,
DYSrcVectorSize
,
XSrcVectorDim
,
XSrcVectorSize
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
DGammaDstVectorSize
,
DBetaDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
dyStrides
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
meanStrides
,
const
std
::
vector
<
index_t
>
invStdStrides
,
const
std
::
vector
<
index_t
>
outLengths
,
const
std
::
vector
<
index_t
>
dgammaStrides
,
const
std
::
vector
<
index_t
>
dbetaStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
DYDataType
*
p_dy
,
const
XDataType
*
p_x
,
const
MeanInvStdDataType
*
p_mean
,
const
MeanInvStdDataType
*
p_invStd
,
DGammaDataType
*
p_dgamma
,
DBetaDataType
*
p_dbeta
)
:
p_dy_
(
p_dy
),
p_x_
(
p_x
),
p_mean_
(
p_mean
),
p_invStd_
(
p_invStd
),
p_dgamma_
(
p_dgamma
),
p_dbeta_
(
p_dbeta
),
outLengths_
{
outLengths
},
dgammaStrides_
{
dgammaStrides
},
dbetaStrides_
{
dbetaStrides
}
{
inLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inLengths
,
reduceDims
);
dyStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
dyStrides
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
meanStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
meanStrides
,
reduceDims
);
invStdStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
invStdStrides
,
reduceDims
);
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
inLengths_
);
numBlockTileIteration_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
);
gridSize_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
);
dy_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
inLengths_
,
dyStrides_
,
numBlockTileIteration_
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
inLengths_
,
xStrides_
,
numBlockTileIteration_
);
mean_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
inLengths_
,
meanStrides_
,
numBlockTileIteration_
);
inv_std_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
inLengths_
,
invStdStrides_
,
numBlockTileIteration_
);
dgamma_grid_desc_m_
=
MakeDst1dDescriptor
(
outLengths_
,
dgammaStrides_
);
dbeta_grid_desc_m_
=
MakeDst1dDescriptor
(
outLengths_
,
dbetaStrides_
);
}
const
DYDataType
*
p_dy_
;
const
XDataType
*
p_x_
;
const
MeanInvStdDataType
*
p_mean_
;
const
MeanInvStdDataType
*
p_invStd_
;
DGammaDataType
*
p_dgamma_
;
DBetaDataType
*
p_dbeta_
;
std
::
vector
<
index_t
>
inLengths_
;
std
::
vector
<
index_t
>
dyStrides_
;
std
::
vector
<
index_t
>
xStrides_
;
std
::
vector
<
index_t
>
meanStrides_
;
std
::
vector
<
index_t
>
invStdStrides_
;
std
::
vector
<
index_t
>
outLengths_
;
std
::
vector
<
index_t
>
dgammaStrides_
;
std
::
vector
<
index_t
>
dbetaStrides_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
// Source descriptor
GridDesc_M_K
dy_grid_desc_m_k_
;
GridDesc_M_K
x_grid_desc_m_k_
;
GridDesc_M_K
mean_grid_desc_m_k_
;
GridDesc_M_K
inv_std_grid_desc_m_k_
;
// Destination descriptor
GridDesc_M
dgamma_grid_desc_m_
;
GridDesc_M
dbeta_grid_desc_m_
;
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel_main
=
kernel_normalization_bwd_gamma_beta
<
GridwiseNormalizationBwdGammaBeta
,
DYDataType
,
XDataType
,
MeanInvStdDataType
,
DGammaDataType
,
DBetaDataType
,
GridDesc_M_K
,
GridDesc_M
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel_main
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
dy_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
arg
.
mean_grid_desc_m_k_
,
arg
.
inv_std_grid_desc_m_k_
,
arg
.
dgamma_grid_desc_m_
,
arg
.
dbeta_grid_desc_m_
,
arg
.
numBlockTileIteration_
,
arg
.
p_dy_
,
arg
.
p_x_
,
arg
.
p_mean_
,
arg
.
p_invStd_
,
arg
.
p_dgamma_
,
arg
.
p_dbeta_
);
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
template
<
index_t
SrcVectorDim
,
index_t
SrcVectorSize
>
bool
IsSrcVectorDimSizeValid
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
if
constexpr
(
SrcVectorSize
==
1
)
return
true
;
// Fastest dimension is not reduced
if
constexpr
(
SrcVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
return
false
;
if
(
strides
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
lengths
[
NumInvariantDim
-
1
]
%
SrcVectorSize
!=
0
)
return
false
;
}
else
// Fastest dimension is reduced
{
if
(
strides
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
lengths
[
Rank
-
1
]
%
SrcVectorSize
!=
0
)
return
false
;
};
return
true
;
}
template
<
index_t
DstVectorSize
>
bool
IsDstVectorSizeValid
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
if
constexpr
(
DstVectorSize
==
1
)
return
true
;
if
(
strides
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
lengths
[
NumInvariantDim
-
1
]
%
DstVectorSize
!=
0
)
return
false
;
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
bool
pass
=
true
;
pass
&=
IsSrcVectorDimSizeValid
<
DYSrcVectorDim
,
DYSrcVectorSize
>
(
p_arg_
->
inLengths_
,
p_arg_
->
dyStrides_
);
pass
&=
IsSrcVectorDimSizeValid
<
XSrcVectorDim
,
XSrcVectorSize
>
(
p_arg_
->
inLengths_
,
p_arg_
->
xStrides_
);
pass
&=
IsSrcVectorDimSizeValid
<
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
>
(
p_arg_
->
inLengths_
,
p_arg_
->
meanStrides_
);
pass
&=
IsSrcVectorDimSizeValid
<
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
>
(
p_arg_
->
inLengths_
,
p_arg_
->
invStdStrides_
);
pass
&=
IsDstVectorSizeValid
<
DGammaDstVectorSize
>
(
p_arg_
->
outLengths_
,
p_arg_
->
dgammaStrides_
);
pass
&=
IsDstVectorSizeValid
<
DBetaDstVectorSize
>
(
p_arg_
->
outLengths_
,
p_arg_
->
dbetaStrides_
);
return
pass
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
dyStrides
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
meanStrides
,
const
std
::
vector
<
index_t
>
invStdStrides
,
const
std
::
vector
<
index_t
>
outLengths
,
const
std
::
vector
<
index_t
>
dgammaStrides
,
const
std
::
vector
<
index_t
>
dbetaStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
void
*
p_dy
,
const
void
*
p_x
,
const
void
*
p_mean
,
const
void
*
p_invStd
,
void
*
p_dgamma
,
void
*
p_dbeta
)
override
{
if
(
inLengths
.
size
()
!=
Rank
||
dyStrides
.
size
()
!=
Rank
||
xStrides
.
size
()
!=
Rank
||
meanStrides
.
size
()
!=
Rank
||
invStdStrides
.
size
()
!=
Rank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
outLengths
.
size
()
!=
NumInvariantDim
||
dgammaStrides
.
size
()
!=
NumInvariantDim
||
dbetaStrides
.
size
()
!=
NumInvariantDim
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
return
std
::
make_unique
<
Argument
>
(
inLengths
,
dyStrides
,
xStrides
,
meanStrides
,
invStdStrides
,
outLengths
,
dgammaStrides
,
dbetaStrides
,
reduceDims
,
static_cast
<
const
DYDataType
*>
(
p_dy
),
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
MeanInvStdDataType
*>
(
p_mean
),
static_cast
<
const
MeanInvStdDataType
*>
(
p_invStd
),
static_cast
<
DGammaDataType
*>
(
p_dgamma
),
static_cast
<
DBetaDataType
*>
(
p_dbeta
));
}
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
45c6c530
...
...
@@ -85,10 +85,13 @@ struct Add
struct
ScaleAdd
{
__host__
__device__
ScaleAdd
(
float
scale
)
:
scale_
(
scale
)
{}
__host__
__device__
ScaleAdd
(
float
scale
=
1.
f
)
:
scale_
(
scale
)
{}
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
;
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
{
y
=
ck
::
type_convert
<
Y
>
(
scale_
*
ck
::
type_convert
<
float
>
(
x0
)
+
ck
::
type_convert
<
float
>
(
x1
));
}
template
<
>
__host__
__device__
void
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
45c6c530
...
...
@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// A desc for source in blockwise copy
template
<
typename
AGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
Make
Default
AGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
...
...
@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template
<
typename
AsGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeAsGridDescriptor_AK0_M_AK1
(
const
AsGridDesc_M_K
&
as_grid_desc_m_k
)
Make
Default
AsGridDescriptor_AK0_M_AK1
(
const
AsGridDesc_M_K
&
as_grid_desc_m_k
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeAGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k
[
i
]);
},
[
&
](
auto
i
)
{
return
Make
Default
AGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k
[
i
]);
},
Number
<
NumATensor
>
{});
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
Make
Default
BGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
...
...
@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template
<
typename
BsGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBsGridDescriptor_BK0_N_BK1
(
const
BsGridDesc_N_K
&
bs_grid_desc_n_k
)
Make
Default
BsGridDescriptor_BK0_N_BK1
(
const
BsGridDesc_N_K
&
bs_grid_desc_n_k
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeBGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k
[
i
]);
},
[
&
](
auto
i
)
{
return
Make
Default
BGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k
[
i
]);
},
Number
<
NumBTensor
>
{});
}
...
...
@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// return block_id to E matrix tile idx (m0, n0) mapping
template
<
typename
EGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeBlock2ETileMap
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
Make
Default
Block2ETileMap
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
EGridDesc_M_N
>
(
e_grid_desc_m_n
);
...
...
@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple
([
&
](
auto
)
{
return
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
);
},
Number
<
NumATensor
>
{});
static_assert
(
ABlockTransferSrcScalarPerVector
==
ABlockTransferDstScalarPerVector_AK1
,
"Src and Dst ScalarPerVector must be the same"
);
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v7r2
<
ThisThreadBlock
,
AsDataType
,
...
...
@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple
([
&
](
auto
)
{
return
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
);
},
Number
<
NumBTensor
>
{});
static_assert
(
BBlockTransferSrcScalarPerVector
==
BBlockTransferDstScalarPerVector_BK1
,
"Src and Dst ScalarPerVector must be the same"
);
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v7r2
<
ThisThreadBlock
,
BsDataType
,
...
...
@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
const
auto
e_grid_desc_m_n
=
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>
(
M
,
N
,
StrideE
);
// tensor descriptors for block/thread-wise copy
const
auto
as_grid_desc_ak0_m_ak1
=
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k
);
const
auto
as_grid_desc_ak0_m_ak1
=
Make
Default
AsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k
);
const
auto
bs_grid_desc_bk0_n_bk1
=
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k
);
const
auto
bs_grid_desc_bk0_n_bk1
=
Make
Default
BsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k
);
const
auto
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
);
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
0 → 100644
View file @
45c6c530
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace
ck
{
// dgamma = reduce_sum(dy * (x - mean) * inv_std)
// dbeta = reduce_sum(dy)
template
<
typename
DYDataType
,
typename
XDataType
,
typename
MeanInvStdDataType
,
typename
ComputeDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
DYSrcVectorDim
,
index_t
DYSrcVectorSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
MeanInvStdSrcVectorDim
,
index_t
MeanInvStdSrcVectorSize
,
index_t
DGammaDstVectorSize
,
index_t
DBetaDstVectorSize
>
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
{
// if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
static_assert
(((
XSrcVectorDim
==
0
&&
MThreadSliceSize
==
XSrcVectorSize
)
||
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
DYThreadBufferDimAccessOrder
=
typename
conditional
<
DYSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
XThreadBufferDimAccessOrder
=
typename
conditional
<
XSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
MeanInvStdThreadBufferDimAccessOrder
=
typename
conditional
<
MeanInvStdSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
DYThreadBufferDimAccessOrder
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
ComputeDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Add
,
true
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
GridDesc_M_K
&
dy_grid_desc_m_k
,
const
GridDesc_M_K
&
x_grid_desc_m_k
,
const
GridDesc_M_K
&
mean_grid_desc_m_k
,
const
GridDesc_M_K
&
inv_std_grid_desc_m_k
,
const
GridDesc_M
&
dgamma_grid_desc_m
,
const
GridDesc_M
&
dbeta_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
const
DYDataType
*
const
__restrict__
p_dy_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_mean_global
,
const
MeanInvStdDataType
*
const
__restrict__
p_inv_std_global
,
DGammaDataType
*
const
__restrict__
p_dgamma_global
,
DBetaDataType
*
const
__restrict__
p_dbeta_global
)
{
// LDS
__shared__
ComputeDataType
p_reduce_work_buffer
[
BlockSize
];
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
// Global
const
auto
dy_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dy_global
,
dy_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_mean_global
,
mean_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_inv_std_global
,
inv_std_grid_desc_m_k
.
GetElementSpaceSize
());
auto
dgamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dgamma_global
,
dgamma_grid_desc_m
.
GetElementSpaceSize
());
auto
dbeta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dbeta_global
,
dbeta_grid_desc_m
.
GetElementSpaceSize
());
// VGPR
auto
dy_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
x_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
mean_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
inv_std_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
{};
auto
dgamma_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
{};
auto
dbeta_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
{};
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
// IO
auto
threadwise_dy_load
=
ThreadwiseTensorSliceTransfer_v2
<
DYDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
DYThreadBufferDimAccessOrder
,
DYSrcVectorDim
,
DYSrcVectorSize
,
1
,
true
>
(
dy_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
XThreadBufferDimAccessOrder
,
XSrcVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_mean_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanInvStdDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
MeanInvStdThreadBufferDimAccessOrder
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
1
,
true
>
(
mean_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_inv_std_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanInvStdDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
MeanInvStdThreadBufferDimAccessOrder
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
1
,
true
>
(
inv_std_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_dgamma_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
DGammaDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
DGammaDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
dgamma_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_dbeta_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
DBetaDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
DBetaDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
dbeta_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
dgamma_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
dbeta_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_dy_load
.
Run
(
dy_grid_desc_m_k
,
dy_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dy_thread_buf
);
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_mean_load
.
Run
(
mean_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
);
threadwise_inv_std_load
.
Run
(
inv_std_grid_desc_m_k
,
inv_std_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
inv_std_thread_buf
);
threadwise_dy_load
.
MoveSrcSliceWindow
(
dy_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_mean_load
.
MoveSrcSliceWindow
(
mean_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_inv_std_load
.
MoveSrcSliceWindow
(
inv_std_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
dgamma_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
*
(
x_thread_buf
[
offset_m_k
]
-
mean_thread_buf
[
offset_m_k
]);
dbeta_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
];
});
});
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
dbeta_thread_buf
(
I
));
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
dgamma_thread_buf
(
I
));
});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_dgamma_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dgamma_thread_buf
,
dgamma_grid_desc_m
,
dgamma_global_val_buf
);
threadwise_dbeta_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dbeta_thread_buf
,
dbeta_grid_desc_m
,
dbeta_global_val_buf
);
}
}
};
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
45c6c530
...
...
@@ -3,12 +3,23 @@
#pragma once
#include <iostream>
#include <cmath>
#include <cstdlib>
#include <numeric>
#include <type_traits>
#include <
sstream
>
#include <
vector
>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -22,6 +33,7 @@ namespace host {
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order
//
// @tparam NDimSpatial Number of spatial dimensions.
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type.
...
...
@@ -29,7 +41,9 @@ namespace host {
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NDimSpatial Number of spatial dimensions.
// @tparam NumAElementwiseTensor Number of A elementwise tensors.
// @tparam NumBElementwiseTensor Number of B elementwise tensors.
// @tparam NumDElementwiseTensor Number of D elementwise tensors.
//
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
...
...
@@ -42,28 +56,35 @@ template <ck::index_t NDimSpatial,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
NumDTensor
=
0
,
ck
::
index_t
NumAElementwiseTensor
=
0
,
ck
::
index_t
NumBElementwiseTensor
=
0
,
ck
::
index_t
NumDElementwiseTensor
=
0
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvFwd
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDTensor
>&
d_tensors
)
Argument
(
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
std
::
array
<
Tensor
<
InDataType
>
,
NumAElementwiseTensor
>&
elementwise_a_tensors
,
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors
,
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors
)
:
input_
{
input
},
weight_
{
weight
},
output_
{
output
},
d_tensors_
{
d_tensors
},
elementwise_a_tensors_
{
elementwise_a_tensors
},
elementwise_b_tensors_
{
elementwise_b_tensors
},
elementwise_d_tensors_
{
elementwise_d_tensors
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
...
...
@@ -78,7 +99,9 @@ struct ReferenceConvFwd : public device::BaseOperator
const
Tensor
<
WeiDataType
>&
weight_
;
Tensor
<
OutDataType
>&
output_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDTensor
>&
d_tensors_
;
const
std
::
array
<
Tensor
<
InDataType
>
,
NumAElementwiseTensor
>&
elementwise_a_tensors_
;
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
...
...
@@ -119,42 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
3
])
{
float
v_in
;
float
v_wei
;
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
wi
)));
arg
.
wei_element_op_
(
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
x
)));
v_acc
+=
v_in
*
v_wei
;
InDataType
v_in
;
WeiDataType
v_wei
;
ExecuteElementwiseOp
(
arg
.
in_element_op_
,
arg
.
elementwise_a_tensors_
,
Number
<
NumAElementwiseTensor
>
{},
v_in
,
arg
.
input_
(
g
,
n
,
c
,
wi
),
g
,
n
,
c
,
wi
);
ExecuteElementwiseOp
(
arg
.
wei_element_op_
,
arg
.
elementwise_b_tensors_
,
Number
<
NumBElementwiseTensor
>
{},
v_wei
,
arg
.
weight_
(
g
,
k
,
c
,
x
),
g
,
k
,
c
,
x
);
v_acc
+=
ck
::
type_convert
<
float
>
(
v_in
)
*
ck
::
type_convert
<
float
>
(
v_wei
);
}
}
}
OutDataType
v_out
;
OutDataType
v_acc_converted
=
ck
::
type_convert
<
OutDataType
>
(
v_acc
);
if
constexpr
(
NumDTensor
==
0
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
);
}
else
if
constexpr
(
NumDTensor
==
1
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
,
arg
.
d_tensors_
[
0
](
g
,
n
,
k
,
wo
));
}
else
if
constexpr
(
NumDTensor
==
2
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
,
arg
.
d_tensors_
[
0
](
g
,
n
,
k
,
wo
),
arg
.
d_tensors_
[
1
](
g
,
n
,
k
,
wo
));
}
else
{
throw
std
::
runtime_error
(
"Output ElementOp not supported in reference."
);
}
arg
.
output_
(
g
,
n
,
k
,
wo
)
=
v_out
;
OutDataType
&
v_out
=
arg
.
output_
(
g
,
n
,
k
,
wo
);
ExecuteElementwiseOp
(
arg
.
out_element_op_
,
arg
.
elementwise_d_tensors_
,
Number
<
NumDElementwiseTensor
>
{},
v_out
,
v_acc_converted
,
g
,
n
,
k
,
wo
);
};
make_ParallelTensorFunctor
(
func
,
...
...
@@ -191,44 +215,47 @@ struct ReferenceConvFwd : public device::BaseOperator
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
4
])
{
float
v_in
;
float
v_wei
;
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
hi
,
wi
)));
arg
.
wei_element_op_
(
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
y
,
x
)));
v_acc
+=
v_in
*
v_wei
;
InDataType
v_in
;
WeiDataType
v_wei
;
ExecuteElementwiseOp
(
arg
.
in_element_op_
,
arg
.
elementwise_a_tensors_
,
Number
<
NumAElementwiseTensor
>
{},
v_in
,
arg
.
input_
(
g
,
n
,
c
,
hi
,
wi
),
g
,
n
,
c
,
hi
,
wi
);
ExecuteElementwiseOp
(
arg
.
wei_element_op_
,
arg
.
elementwise_b_tensors_
,
Number
<
NumBElementwiseTensor
>
{},
v_wei
,
arg
.
weight_
(
g
,
k
,
c
,
y
,
x
),
g
,
k
,
c
,
y
,
x
);
v_acc
+=
ck
::
type_convert
<
float
>
(
v_in
)
*
ck
::
type_convert
<
float
>
(
v_wei
);
}
}
}
}
OutDataType
v_out
;
OutDataType
v_acc_converted
=
ck
::
type_convert
<
OutDataType
>
(
v_acc
);
if
constexpr
(
NumDTensor
==
0
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
);
}
else
if
constexpr
(
NumDTensor
==
1
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
,
arg
.
d_tensors_
[
0
](
g
,
n
,
k
,
ho
,
wo
));
}
else
if
constexpr
(
NumDTensor
==
2
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
,
arg
.
d_tensors_
[
0
](
g
,
n
,
k
,
ho
,
wo
),
arg
.
d_tensors_
[
1
](
g
,
n
,
k
,
ho
,
wo
));
}
else
{
throw
std
::
runtime_error
(
"Output ElementOp not supported in reference."
);
}
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
)
=
v_out
;
OutDataType
&
v_out
=
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
);
ExecuteElementwiseOp
(
arg
.
out_element_op_
,
arg
.
elementwise_d_tensors_
,
Number
<
NumDElementwiseTensor
>
{},
v_out
,
v_acc_converted
,
g
,
n
,
k
,
ho
,
wo
);
};
make_ParallelTensorFunctor
(
func
,
...
...
@@ -275,47 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
5
])
{
float
v_in
;
float
v_wei
;
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
di
,
hi
,
wi
)));
arg
.
wei_element_op_
(
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
z
,
y
,
x
)));
v_acc
+=
v_in
*
v_wei
;
InDataType
v_in
;
WeiDataType
v_wei
;
ExecuteElementwiseOp
(
arg
.
in_element_op_
,
arg
.
elementwise_a_tensors_
,
Number
<
NumAElementwiseTensor
>
{},
v_in
,
arg
.
input_
(
g
,
n
,
c
,
di
,
hi
,
wi
),
g
,
n
,
c
,
di
,
hi
,
wi
);
ExecuteElementwiseOp
(
arg
.
wei_element_op_
,
arg
.
elementwise_b_tensors_
,
Number
<
NumBElementwiseTensor
>
{},
v_wei
,
arg
.
weight_
(
g
,
k
,
c
,
z
,
y
,
x
),
g
,
k
,
c
,
z
,
y
,
x
);
v_acc
+=
ck
::
type_convert
<
float
>
(
v_in
)
*
ck
::
type_convert
<
float
>
(
v_wei
);
}
}
}
}
}
OutDataType
v_out
;
OutDataType
v_acc_converted
=
ck
::
type_convert
<
OutDataType
>
(
v_acc
);
if
constexpr
(
NumDTensor
==
0
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
);
}
else
if
constexpr
(
NumDTensor
==
1
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
,
arg
.
d_tensors_
[
0
](
g
,
n
,
k
,
d_o
,
ho
,
wo
));
}
else
if
constexpr
(
NumDTensor
==
2
)
{
arg
.
out_element_op_
(
v_out
,
v_acc_converted
,
arg
.
d_tensors_
[
0
](
g
,
n
,
k
,
d_o
,
ho
,
wo
),
arg
.
d_tensors_
[
1
](
g
,
n
,
k
,
d_o
,
ho
,
wo
));
}
else
{
throw
std
::
runtime_error
(
"Output ElementOp not supported in reference."
);
}
arg
.
output_
(
g
,
n
,
k
,
d_o
,
ho
,
wo
)
=
v_out
;
OutDataType
&
v_out
=
arg
.
output_
(
g
,
n
,
k
,
d_o
,
ho
,
wo
);
ExecuteElementwiseOp
(
arg
.
out_element_op_
,
arg
.
elementwise_d_tensors_
,
Number
<
NumDElementwiseTensor
>
{},
v_out
,
v_acc_converted
,
g
,
n
,
k
,
d_o
,
ho
,
wo
);
};
make_ParallelTensorFunctor
(
func
,
...
...
@@ -338,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator
}
};
template
<
typename
...
Args
,
typename
ElementwiseOp
,
typename
ElementwiseTensor
,
typename
NumTensor
,
typename
T
>
static
void
ExecuteElementwiseOp
(
ElementwiseOp
&
elementwise_op
,
ElementwiseTensor
&
elementwise_tensors
,
NumTensor
,
T
&
y
,
const
T
&
x
,
Args
...
dims
)
{
if
constexpr
(
NumTensor
::
value
==
0
)
{
elementwise_op
(
y
,
x
);
}
else
if
constexpr
(
NumTensor
::
value
==
1
)
{
elementwise_op
(
y
,
x
,
elementwise_tensors
[
0
](
dims
...));
}
else
if
constexpr
(
NumTensor
::
value
==
2
)
{
elementwise_op
(
y
,
x
,
elementwise_tensors
[
0
](
dims
...),
elementwise_tensors
[
1
](
dims
...));
}
else
{
throw
std
::
runtime_error
(
"ElementOp not supported in reference."
);
}
}
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
...
...
@@ -349,17 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator
return
NDimSpatial
>=
1
&&
NDimSpatial
<=
3
;
}
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDTensor
>&
d_tensors
=
{})
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
std
::
array
<
Tensor
<
InDataType
>
,
NumAElementwiseTensor
>&
elementwise_a_tensors
=
{},
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors
=
{},
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors
=
{})
{
return
Argument
{
input
,
weight
,
...
...
@@ -371,7 +435,9 @@ struct ReferenceConvFwd : public device::BaseOperator
in_element_op
,
wei_element_op
,
out_element_op
,
d_tensors
};
elementwise_a_tensors
,
elementwise_b_tensors
,
elementwise_d_tensors
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
0 → 100644
View file @
45c6c530
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
typename
DXDataType
,
typename
ComputeDataType
>
struct
ReferenceGroupnormBwd
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
DYDataType
>&
dy_nhwgc
,
const
Tensor
<
XDataType
>&
x_nhwgc
,
const
Tensor
<
GammaDataType
>&
gamma_gc
,
const
Tensor
<
MeanInvStdDataType
>&
mean_ng
,
const
Tensor
<
MeanInvStdDataType
>&
inv_std_ng
,
Tensor
<
DGammaDataType
>&
dgamma_gc
,
Tensor
<
DBetaDataType
>&
dbeta_gc
,
Tensor
<
DXDataType
>&
dx_nhwgc
,
const
std
::
vector
<
index_t
>
lengths
)
:
dy_nhwgc_
(
dy_nhwgc
),
x_nhwgc_
(
x_nhwgc
),
gamma_gc_
(
gamma_gc
),
mean_ng_
(
mean_ng
),
inv_std_ng_
(
inv_std_ng
),
dgamma_gc_
(
dgamma_gc
),
dbeta_gc_
(
dbeta_gc
),
dx_nhwgc_
(
dx_nhwgc
),
lengths_
(
lengths
)
{
}
const
Tensor
<
DYDataType
>&
dy_nhwgc_
;
const
Tensor
<
XDataType
>&
x_nhwgc_
;
const
Tensor
<
GammaDataType
>&
gamma_gc_
;
const
Tensor
<
MeanInvStdDataType
>&
mean_ng_
;
const
Tensor
<
MeanInvStdDataType
>&
inv_std_ng_
;
Tensor
<
DGammaDataType
>&
dgamma_gc_
;
Tensor
<
DBetaDataType
>&
dbeta_gc_
;
Tensor
<
DXDataType
>&
dx_nhwgc_
;
std
::
vector
<
index_t
>
lengths_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
float
Run
(
const
Argument
&
arg
)
{
int
N
=
arg
.
lengths_
[
0
];
int
H
=
arg
.
lengths_
[
1
];
int
W
=
arg
.
lengths_
[
2
];
int
G
=
arg
.
lengths_
[
3
];
int
C
=
arg
.
lengths_
[
4
];
// Calculate dgamma and dbeta
for
(
int
g
=
0
;
g
<
G
;
++
g
)
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
ComputeDataType
dgamma
=
0
;
ComputeDataType
dbeta
=
0
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
h
=
0
;
h
<
H
;
++
h
)
for
(
int
w
=
0
;
w
<
W
;
++
w
)
{
ComputeDataType
dy
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
dy_nhwgc_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
x
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_nhwgc_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
mean
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
mean_ng_
(
n
,
g
));
ComputeDataType
rstd
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
inv_std_ng_
(
n
,
g
));
dgamma
+=
dy
*
rstd
*
(
x
-
mean
);
dbeta
+=
dy
;
}
arg
.
dgamma_gc_
(
g
,
c
)
=
ck
::
type_convert
<
DGammaDataType
>
(
dgamma
);
arg
.
dbeta_gc_
(
g
,
c
)
=
ck
::
type_convert
<
DBetaDataType
>
(
dbeta
);
}
// Calculate dx
int
reduce_size
=
H
*
W
*
C
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
g
=
0
;
g
<
G
;
++
g
)
{
ComputeDataType
ds
=
0
;
ComputeDataType
db
=
0
;
ComputeDataType
mean
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
mean_ng_
(
n
,
g
));
ComputeDataType
rstd
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
inv_std_ng_
(
n
,
g
));
for
(
int
h
=
0
;
h
<
H
;
++
h
)
for
(
int
w
=
0
;
w
<
W
;
++
w
)
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
ComputeDataType
dy
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
dy_nhwgc_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
x
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_nhwgc_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
gamma
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
gamma_gc_
(
g
,
c
));
ds
+=
dy
*
gamma
*
x
;
db
+=
dy
*
gamma
;
}
for
(
int
h
=
0
;
h
<
H
;
++
h
)
for
(
int
w
=
0
;
w
<
W
;
++
w
)
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
ComputeDataType
dy
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
dy_nhwgc_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
x
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_nhwgc_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
gamma
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
gamma_gc_
(
g
,
c
));
ComputeDataType
b
=
(
db
*
mean
-
ds
)
*
rstd
*
rstd
*
rstd
/
reduce_size
;
ComputeDataType
c1
=
-
b
*
mean
-
db
*
rstd
/
reduce_size
;
arg
.
dx_nhwgc_
(
n
,
h
,
w
,
g
,
c
)
=
ck
::
type_convert
<
DXDataType
>
(
dy
*
gamma
*
rstd
+
b
*
x
+
c1
);
}
}
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
DYDataType
>&
dy_nhwgc
,
const
Tensor
<
XDataType
>&
x_nhwgc
,
const
Tensor
<
GammaDataType
>&
gamma_gc
,
const
Tensor
<
MeanInvStdDataType
>&
mean_ng
,
const
Tensor
<
MeanInvStdDataType
>&
inv_std_ng
,
Tensor
<
DGammaDataType
>&
dgamma_gc
,
Tensor
<
DBetaDataType
>&
dbeta_gc
,
Tensor
<
DXDataType
>&
dx_nhwgc
,
const
std
::
vector
<
index_t
>
lengths
)
{
return
Argument
{
dy_nhwgc
,
x_nhwgc
,
gamma_gc
,
mean_ng
,
inv_std_ng
,
dgamma_gc
,
dbeta_gc
,
dx_nhwgc
,
lengths
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceGroupnormBwd"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
0 → 100644
View file @
45c6c530
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
typename
DXDataType
,
typename
ComputeDataType
>
struct
ReferenceLayernormBwd
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
DYDataType
>&
dy_m_n
,
const
Tensor
<
XDataType
>&
x_m_n
,
const
Tensor
<
GammaDataType
>&
gamma_n
,
const
Tensor
<
MeanInvStdDataType
>&
mean_m
,
const
Tensor
<
MeanInvStdDataType
>&
inv_std_m
,
Tensor
<
DGammaDataType
>&
dgamma_n
,
Tensor
<
DBetaDataType
>&
dbeta_n
,
Tensor
<
DXDataType
>&
dx_m_n
,
const
std
::
vector
<
index_t
>
lengths
)
:
dy_m_n_
(
dy_m_n
),
x_m_n_
(
x_m_n
),
gamma_n_
(
gamma_n
),
mean_m_
(
mean_m
),
inv_std_m_
(
inv_std_m
),
dgamma_n_
(
dgamma_n
),
dbeta_n_
(
dbeta_n
),
dx_m_n_
(
dx_m_n
),
lengths_
(
lengths
)
{
}
const
Tensor
<
DYDataType
>&
dy_m_n_
;
const
Tensor
<
XDataType
>&
x_m_n_
;
const
Tensor
<
GammaDataType
>&
gamma_n_
;
const
Tensor
<
MeanInvStdDataType
>&
mean_m_
;
const
Tensor
<
MeanInvStdDataType
>&
inv_std_m_
;
Tensor
<
DGammaDataType
>&
dgamma_n_
;
Tensor
<
DBetaDataType
>&
dbeta_n_
;
Tensor
<
DXDataType
>&
dx_m_n_
;
std
::
vector
<
index_t
>
lengths_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
float
Run
(
const
Argument
&
arg
)
{
int
M
=
arg
.
lengths_
[
0
];
int
N
=
arg
.
lengths_
[
1
];
// Calculate dgamma and dbeta
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
dgamma
=
0
;
ComputeDataType
dbeta
=
0
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
ComputeDataType
dy
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
dy_m_n_
(
m
,
n
));
ComputeDataType
x
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
ComputeDataType
mean
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
mean_m_
(
m
));
ComputeDataType
rstd
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
inv_std_m_
(
m
));
dgamma
+=
dy
*
rstd
*
(
x
-
mean
);
dbeta
+=
dy
;
}
arg
.
dgamma_n_
(
n
)
=
ck
::
type_convert
<
DGammaDataType
>
(
dgamma
);
arg
.
dbeta_n_
(
n
)
=
ck
::
type_convert
<
DBetaDataType
>
(
dbeta
);
}
// Calculate dx
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
ComputeDataType
ds
=
0
;
ComputeDataType
db
=
0
;
ComputeDataType
mean
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
mean_m_
(
m
));
ComputeDataType
rstd
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
inv_std_m_
(
m
));
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
dy
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
dy_m_n_
(
m
,
n
));
ComputeDataType
x
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
ComputeDataType
gamma
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
gamma_n_
(
n
));
ds
+=
dy
*
gamma
*
x
;
db
+=
dy
*
gamma
;
}
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
dy
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
dy_m_n_
(
m
,
n
));
ComputeDataType
x
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
ComputeDataType
gamma
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
gamma_n_
(
n
));
ComputeDataType
b
=
(
db
*
mean
-
ds
)
*
rstd
*
rstd
*
rstd
/
N
;
ComputeDataType
c
=
-
b
*
mean
-
db
*
rstd
/
N
;
arg
.
dx_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
DXDataType
>
(
dy
*
gamma
*
rstd
+
b
*
x
+
c
);
}
}
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
DYDataType
>&
dy_m_n
,
const
Tensor
<
XDataType
>&
x_m_n
,
const
Tensor
<
GammaDataType
>&
gamma_n
,
const
Tensor
<
MeanInvStdDataType
>&
mean_m
,
const
Tensor
<
MeanInvStdDataType
>&
inv_std_m
,
Tensor
<
DGammaDataType
>&
dgamma_n
,
Tensor
<
DBetaDataType
>&
dbeta_n
,
Tensor
<
DXDataType
>&
dx_m_n
,
const
std
::
vector
<
index_t
>
lengths
)
{
return
Argument
{
dy_m_n
,
x_m_n
,
gamma_n
,
mean_m
,
inv_std_m
,
dgamma_n
,
dbeta_n
,
dx_m_n
,
lengths
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceLayernormBwd"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment