gaoqiong / composable_kernel_ROCM / Commits

Commit 4a106f7d, authored Nov 01, 2023 by illsilin

    merge from the public repo

Parents: a73ab0d8, 306fd506
Changes: 601 files in total; showing 20 changed files with 3862 additions and 434 deletions (+3862, -434).
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp        +24   -59
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp                              +66   -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp                      +10   -7
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  +15   -14
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp                                +29   -15
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp                       +839  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp                +632  -0
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp                            +407  -0
include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp                               +325  -0
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp                      +1    -1
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp                      +1    -1
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp                              +117  -43
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp                       +748  -0
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp                                    +1    -1
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp                            +76   -290
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp                          +411  -0
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp                                +155  -0
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp                                   +1    -1
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp                               +1    -1
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp                               +3    -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp

@@ -19,6 +19,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/io.hpp"

@@ -29,51 +30,6 @@ namespace device {
 namespace {

-template <index_t NumDTensor>
-struct ComputePtrOffsetOfStridedBatch
-{
-    ComputePtrOffsetOfStridedBatch() = default;
-
-    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
-                                   index_t BatchStrideB,
-                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
-                                   index_t BatchStrideE)
-        : BatchStrideA_(BatchStrideA),
-          BatchStrideB_(BatchStrideB),
-          BatchStrideDs_(BatchStrideDs),
-          BatchStrideE_(BatchStrideE)
-    {
-    }
-
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
-    }
-
-    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
-    {
-        Array<long_index_t, NumDTensor> ds_offset;
-        static_for<0, NumDTensor, 1>{}(
-            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
-        return ds_offset;
-    }
-
-    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
-    }
-
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
-    Array<ck::index_t, NumDTensor> BatchStrideDs_;
-    index_t BatchStrideE_;
-};
-
 /*
  * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
  *

@@ -136,7 +92,7 @@ __global__ void
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -255,7 +211,8 @@ template <index_t NDimSpatial,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          typename ComputeDataType = ADataType,
+          LoopScheduler LoopSched  = make_default_loop_scheduler()>
 struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                            ALayout,

@@ -268,7 +225,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                                            EDataType,
                                            AElementwiseOperation,
                                            BElementwiseOperation,
-                                           CDEElementwiseOperation>
+                                           CDEElementwiseOperation,
+                                           ComputeDataType>
 {
     using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;

@@ -361,8 +319,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     }

     // desc for problem definition
     using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
         {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_N_K  = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
     using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

@@ -370,6 +328,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,

@@ -412,14 +372,18 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         LoopSched>;

     // desc for blockwise copy
     using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
         GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
     using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
         GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
     using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

     // block-to-e-tile map
     using Block2ETileMap =

@@ -685,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                 return false;
             }
         }
-        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
+        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" ||
+                get_device_name() == "gfx941" || get_device_name() == "gfx942")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
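The new ComputeDataType template parameter defaults to ADataType, so existing instantiations keep their behavior while callers can opt into a different compute precision. A minimal sketch of the same defaulted-parameter pattern with toy types and names (not the CK API):

#include <cstdio>

// Toy illustration of a defaulted compute-type template parameter:
// the arithmetic is done in ComputeT, while inputs stay in DataT.
template <typename DataT, typename ComputeT = DataT>
ComputeT dot3(const DataT* a, const DataT* b)
{
    ComputeT acc = 0;
    for(int i = 0; i < 3; ++i)
        acc += static_cast<ComputeT>(a[i]) * static_cast<ComputeT>(b[i]);
    return acc;
}

int main()
{
    float a[3] = {1.f, 2.f, 3.f};
    float b[3] = {4.f, 5.f, 6.f};
    printf("%f\n", dot3(a, b));                // accumulate in float (the default)
    printf("%f\n", dot3<float, double>(a, b)); // opt into double accumulation
    return 0;
}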
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp   (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
    ComputePtrOffsetOfStridedBatch() = default;

    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                   index_t BatchStrideB,
                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
                                   index_t BatchStrideE)
        : BatchStrideA_(BatchStrideA),
          BatchStrideB_(BatchStrideB),
          BatchStrideDs_(BatchStrideDs),
          BatchStrideE_(BatchStrideE)
    {
    }

    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideA_);
    }

    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB_);
    }

    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
    {
        Array<long_index_t, NumDTensor> ds_offset;
        static_for<0, NumDTensor, 1>{}(
            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
        return ds_offset;
    }

    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideE_);
    }

    // alias for kernels without multiple D
    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideE_);
    }

    index_t BatchStrideA_;
    index_t BatchStrideB_;
    Array<ck::index_t, NumDTensor> BatchStrideDs_;
    index_t BatchStrideE_;
    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
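As a rough illustration of how these offsets are consumed, each work-group multiplies its group index by the per-tensor batch stride to find where its slice of A, B, Ds and E begins. A self-contained host-side analogue with made-up strides (not code from this commit):

#include <array>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for ComputePtrOffsetOfStridedBatch with two D tensors.
struct PtrOffsetOfStridedBatch
{
    int32_t batch_stride_a, batch_stride_b, batch_stride_e;
    std::array<int32_t, 2> batch_stride_ds;

    int64_t GetAPtrOffset(int32_t g) const { return g * int64_t(batch_stride_a); }
    int64_t GetBPtrOffset(int32_t g) const { return g * int64_t(batch_stride_b); }
    int64_t GetEPtrOffset(int32_t g) const { return g * int64_t(batch_stride_e); }
};

int main()
{
    // Illustrative strides only: A is 1024x64, B is 64x256, E is 1024x256 per batch.
    PtrOffsetOfStridedBatch off{1024 * 64, 64 * 256, 1024 * 256, {1024 * 256, 256}};
    // The work-group handling batch index 3 starts reading A at this element offset:
    printf("A offset for g=3: %lld\n", static_cast<long long>(off.GetAPtrOffset(3)));
    return 0;
}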
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp

-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -39,8 +39,9 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
+    defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -596,10 +597,12 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
             }
         }

-        hipGetErrorString(hipMemcpy(arg.p_workspace_,
-                                    arg.gemm_desc_kernel_arg_.data(),
-                                    arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
-                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
+                                              arg.gemm_desc_kernel_arg_.data(),
+                                              arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
+                                              hipMemcpyHostToDevice,
+                                              stream_config.stream_id_));

         auto launch_kernel = [&](auto has_main_k_block_loop, auto has_double_tail_k_block_loop) {
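The workspace upload now uses hipMemcpyWithStream, so the copy is ordered on the same stream the kernel is later launched on rather than on the default stream. A minimal standalone sketch of the pattern; buffer contents and names are illustrative:

#include <hip/hip_runtime.h>
#include <vector>

int main()
{
    hipStream_t stream;
    hipStreamCreate(&stream);

    std::vector<int> host_args(256, 1); // stand-in for packed kernel arguments
    int* dev_args = nullptr;
    hipMalloc(&dev_args, host_args.size() * sizeof(int));

    // Stream-ordered host-to-device copy: kernels launched afterwards on `stream`
    // see the uploaded arguments without a global synchronization.
    hipMemcpyWithStream(dev_args, host_args.data(), host_args.size() * sizeof(int),
                        hipMemcpyHostToDevice, stream);

    // ... launch kernels on `stream` that read dev_args ...

    hipStreamSynchronize(stream);
    hipFree(dev_args);
    hipStreamDestroy(stream);
    return 0;
}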
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
  → include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp   (renamed)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -44,7 +44,7 @@ __global__ void
     const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -611,10 +611,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             some_has_main_k_block_loop |= y;
         }

-        hipGetErrorString(hipMemcpy(arg.p_workspace_,
-                                    arg.group_kernel_args_.data(),
-                                    arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
-                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
+                                              arg.group_kernel_args_.data(),
+                                              arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
+                                              hipMemcpyHostToDevice,
+                                              stream_config.stream_id_));

         float ave_time = 0;

@@ -679,8 +681,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+        if(!ck::is_xdl_supported())
         {
             return false;
         }

@@ -734,12 +735,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }

         // Check vector load/store requirement
         const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 ? device_arg.a_mz_kz_strides_[1]
                                                                      : device_arg.a_mz_kz_strides_[0];
         const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2 ? device_arg.b_nz_kz_strides_[1]
                                                                      : device_arg.b_nz_kz_strides_[0];
         const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                           ? device_arg.b1_nz_kz_strides_[1]
                                           : device_arg.b1_nz_kz_strides_[0];
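Several of these files replace hard-coded device-name comparisons with ck::is_xdl_supported(), so newly supported XDL targets only need to be added in one place. A rough sketch of what such a centralized check can look like; this is an illustrative stand-in, not CK's actual implementation:

#include <iostream>
#include <string>

// Illustrative helper: report whether a device name corresponds to an XDL (MFMA) capable GPU.
// The name strings mirror the targets listed in the guards above.
inline bool is_xdl_supported_sketch(const std::string& device_name)
{
    return device_name == "gfx908" || device_name == "gfx90a" || device_name == "gfx940" ||
           device_name == "gfx941" || device_name == "gfx942";
}

int main()
{
    std::cout << is_xdl_supported_sketch("gfx942") << " " << is_xdl_supported_sketch("gfx1030") << "\n";
    return 0;
}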
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -39,7 +39,7 @@ __global__ void
     const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -228,9 +228,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

+    using ComputeDataType = ADataType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,

@@ -272,14 +276,18 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

     using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
         GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
     using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
         GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
     using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

     struct GroupedGemmBlock2ETileMap
     {

@@ -548,11 +556,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             }
         }

-        hipGetErrorString(hipMemcpy(arg.p_workspace_,
-                                    arg.gemm_desc_kernel_arg_.data(),
-                                    arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
-                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
+                                              arg.gemm_desc_kernel_arg_.data(),
+                                              arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
+                                              hipMemcpyHostToDevice,
+                                              stream_config.stream_id_));

         float ave_time = 0;

@@ -599,6 +608,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
+
         if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
             arg.skipped_group_count_) != arg.group_count_)
         {
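These workspace uploads still pass the returned hipError_t through hipGetErrorString, which only converts the code to text; the new split-K header further below uses hip_check_error instead. A minimal version of such a check-and-throw helper, written here only as an illustration and not as CK's own definition:

#include <hip/hip_runtime.h>
#include <stdexcept>
#include <string>

// Illustrative helper: throw when a HIP runtime call does not return hipSuccess.
inline void check_hip(hipError_t err, const char* what)
{
    if(err != hipSuccess)
        throw std::runtime_error(std::string(what) + " failed: " + hipGetErrorString(err));
}

int main()
{
    void* p = nullptr;
    check_hip(hipMalloc(&p, 1024), "hipMalloc");
    check_hip(hipFree(p), "hipFree");
    return 0;
}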
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp   (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseGemm, typename GemmDesc, GemmSpecialization GemmSpec,
          typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
          typename DsDataType, typename Block2ETileMap, typename GroupedGemmBlock2ETileMap,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                         uint32_t* barrier_count,
                                         const index_t barrier_size_grp,
                                         const index_t group_count,
                                         const index_t grid_size_grp,
                                         const index_t KBatch,
                                         const AElementwiseOperation a_element_op,
                                         const BElementwiseOperation b_element_op,
                                         const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t block_id = get_block_1d_id();

    const auto gemm_desc_ptr =
        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

    const index_t group_id = block_id / grid_size_grp;

    if(group_id >= group_count)
        return;

    const index_t M = gemm_desc_ptr[group_id].M;
    const index_t N = gemm_desc_ptr[group_id].N;
    const index_t K = gemm_desc_ptr[group_id].K;

    if(M * N * K == 0)
        return;

    const auto StrideA  = gemm_desc_ptr[group_id].StrideA;
    const auto StrideB  = gemm_desc_ptr[group_id].StrideB;
    const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
    const auto StrideE  = gemm_desc_ptr[group_id].StrideE;

    const auto e_grid_desc_m_n =
        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

    const index_t BlockStart = group_id * grid_size_grp;

    const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};

    const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);

    constexpr auto NumDTensor = DsDataType::Size();

    using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());

    DsGridPointer p_ds_grid_;

    static_for<0, NumDTensor, 1>{}([&](auto i) {
        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

        // D pointer
        p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
    });

    index_t id_off   = 0;
    index_t id_local = get_block_1d_id() - BlockStart;

    const index_t mn_blocks = local_grid_size / KBatch;

    while(id_local < local_grid_size)
    {
        const auto block_2_etile_map =
            GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

        auto barrier_count_finished =
            barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, GemmSpec,
                                   ALayout, BLayout, DsLayout, ELayout>(
            gemm_desc_ptr[group_id].p_a_grid,
            gemm_desc_ptr[group_id].p_b_grid,
            p_ds_grid_,
            gemm_desc_ptr[group_id].p_e_grid,
            p_shared,
            barrier_count_finished,
            a_element_op,
            b_element_op,
            c_element_op,
            M, N, K,
            StrideA, StrideB, StrideDs, StrideE,
            KBatch,
            block_2_etile_map);

        id_off += grid_size_grp;
        id_local += grid_size_grp;
    }
#else
    ignore = gemm_descs_const;
    ignore = barrier_count;
    ignore = barrier_size_grp;
    ignore = group_count;
    ignore = grid_size_grp;
    ignore = KBatch;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
#endif
}
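Each thread block derives its group from its flat block id (group_id = block_id / grid_size_grp) and then walks that group's tiles in strides of grid_size_grp, so a fixed-size grid can cover groups whose real tile counts differ. A host-side sketch of the assignment logic with invented numbers:

#include <cstdio>

int main()
{
    const int group_count   = 3;  // GEMMs in the group
    const int grid_size_grp = 8;  // blocks reserved per group
    const int local_grid    = 11; // tiles one group actually has (may exceed grid_size_grp)

    // Blocks 0..7 serve group 0, 8..15 serve group 1, 16..23 serve group 2.
    for(int block_id = 0; block_id < group_count * grid_size_grp; ++block_id)
    {
        const int group_id    = block_id / grid_size_grp;
        const int block_start = group_id * grid_size_grp;

        // Persistent loop: the block revisits its group's tiles in steps of grid_size_grp.
        for(int id_local = block_id - block_start; id_local < local_grid; id_local += grid_size_grp)
            printf("block %2d -> group %d, tile %d\n", block_id, group_id, id_local);
    }
    return 0;
}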
template <typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
          typename ADataType, typename BDataType, typename AccDataType, typename CShuffleDataType,
          typename DsDataType, typename EDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock, ck::index_t NPerBlock, ck::index_t KPerBlock,
          ck::index_t AK1, ck::index_t BK1,
          ck::index_t MPerXDL, ck::index_t NPerXDL,
          ck::index_t MXdlPerWave, ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          typename ComputeType      = ADataType,
          LoopScheduler LoopSched   = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                        BLayout,
                                                                        DsLayout,
                                                                        ELayout,
                                                                        ADataType,
                                                                        BDataType,
                                                                        DsDataType,
                                                                        EDataType,
                                                                        AElementwiseOperation,
                                                                        BElementwiseOperation,
                                                                        CDEElementwiseOperation>
{
    using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        BDataType, ComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType,
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
        NumPrefetch, // NumGemmKPrefetchStage
        BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1,
        MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    template <typename UnderlyingBlockToCTileMap>
    struct OffsettedBlockToCTileMapMLoops
    {
        using underlying_type = UnderlyingBlockToCTileMap;

        __host__ __device__ OffsettedBlockToCTileMapMLoops(
            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
        {
            block_to_ctile_map_ = block_to_ctile_map;
            block_start_        = block_start;
            id_off_             = id_off;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));

            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
        }

        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                 const CTileDim& c_tile_dim) const
        {
            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
        }

        template <typename CGridDesc_M_N>
        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
        }

        template <typename CGridDesc_M_N>
        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
        }

        UnderlyingBlockToCTileMap block_to_ctile_map_;
        index_t block_start_;
        index_t id_off_;
    };

    template <index_t MPerBlock_, index_t NPerBlock_>
    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
    {
        static constexpr auto I0 = Number<0>{};
        static constexpr auto I1 = Number<1>{};

        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;

        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;

        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
                                                                          index_t N,
                                                                          index_t KBatch,
                                                                          index_t M01 = 8)
            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
        {
        }

        template <typename CGridDesc_M_N>
        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
        {
        }

        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
        {
            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
            const auto N0 = math::integer_divide_ceil(N, NPerBlock);

            return M0 * N0 * KBatch_;
        }

        template <typename CGridDesc_M_N>
        __host__ __device__ constexpr index_t
        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
        }

        template <typename CGridDesc_M_N>
        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
        {
            return true;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            auto block_1d_id = idx_top[I0];

            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);

            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups

            const index_t idx_ksplit = block_1d_id / (M0 * N0);
            block_1d_id              = block_1d_id % (M0 * N0);

            index_t idx_N0 = block_1d_id % N0;
            index_t idx_M0 = block_1d_id / N0;

            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

            index_t idx_M00          = idx_M0 / M01_;
            index_t idx_M01          = idx_M0 % M01_;
            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

            return make_tuple(idx_ksplit,
                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
        }

        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                 const CTileDim& /* c_tile_dim */) const
        {
            return true; // always valid provided that user gets grid size from CalculateGridSize()
        }

        private:
        index_t M_;
        index_t N_;
        index_t KBatch_;
        index_t M01_;
    };
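CalculateBottomIndex splits a flat block id into (k-split, m-block, n-block) coordinates, with the M01 "adapt" factor interleaving rows of output tiles. A standalone rework of the same index arithmetic for concrete, made-up sizes:

#include <cstdio>

// Mirror of the tile-swizzling math for MPerBlock = 128, NPerBlock = 128.
int main()
{
    const int M = 512, N = 512, KBatch = 2, M01 = 8;
    const int MPerBlock = 128, NPerBlock = 128;

    const int M0 = (M + MPerBlock - 1) / MPerBlock; // 4 tile rows
    const int N0 = (N + NPerBlock - 1) / NPerBlock; // 4 tile columns

    for(int block_1d_id = 0; block_1d_id < M0 * N0 * KBatch; ++block_1d_id)
    {
        int id = block_1d_id % (M0 * N0 * KBatch);

        const int idx_ksplit = id / (M0 * N0);
        id                   = id % (M0 * N0);

        const int idx_N0 = id % N0;
        const int idx_M0 = id / N0;

        const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

        const int idx_M00          = idx_M0 / M01;
        const int idx_M01          = idx_M0 % M01;
        const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        printf("block %2d -> (ksplit %d, m0 %d, n0 %d)\n",
               block_1d_id,
               idx_ksplit,
               idx_N0_M01_local % M01_adapt + idx_M00 * M01,
               idx_N0_M01_local / M01_adapt);
    }
    return 0;
}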
    using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;

    struct GemmBiasTransKernelArg
    {
        // pointers
        const void* a_ptr_;
        const void* b_ptr_;
        std::array<const void*, NumDTensor> ds_ptr_;
        void* e_ptr_;

        index_t M_, N_, K_;
        index_t StrideA_, StrideB_;
        std::array<index_t, NumDTensor> StrideDs_;
        index_t StrideE_;
    };

    // Argument
    struct Argument : public BaseArgument
    {
        void UpdateKBatch(index_t k_batch)
        {
            k_batch_ = k_batch;

            if(k_batch_ < 1)
            {
                throw std::runtime_error("wrong! k_batch must be > 0");
            }

            const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);

            const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
            const index_t N       = gemm_desc_kernel_arg_[0].N_;

            const auto e_grid_desc_m_n =
                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(AverM, N, StrideE);

            const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};

            grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

            grid_size_ = grid_size_grp_ * group_count_;
        }

        Argument(std::vector<const void*>&,
                 std::vector<const void*>&,
                 std::vector<std::array<const void*, NumDTensor>>&,
                 std::vector<void*>&,
                 std::vector<GemmDesc>& gemm_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation c_element_op)
            : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
        {
            grid_size_ = 0;
            k_batch_   = 1;

            grouped_gemm_kernel_args_dev = nullptr;

            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            gemm_desc_kernel_arg_.reserve(group_count_);

            index_t group_id = 0;

            sum_of_m = gemm_descs[0].M_;
            const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
            const index_t N     = gemm_descs[0].N_;
            const index_t K     = gemm_descs[0].K_;

            for(std::size_t i = 0; i < gemm_descs.size(); i++)
            {
                if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
                {
                    throw std::runtime_error("wrong! M/N/K is not identical");
                }

                a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
                b_mtx_nraw_kraw_.emplace_back(N, K);

                const index_t StrideA = gemm_descs[i].stride_A_;
                const index_t StrideB = gemm_descs[i].stride_B_;
                const index_t StrideE = gemm_descs[i].stride_C_;

                // pointer
                std::array<const void*, NumDTensor> p_ds_grid;
                static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });

                std::array<index_t, NumDTensor> StrideDs;
                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    // using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
                    if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
                    {
                        throw std::runtime_error(
                            "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
                    }

                    StrideDs[j] = gemm_descs[i].stride_Ds_[j];
                });

                const auto e_grid_desc_m_n =
                    GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
                        AverM, N, StrideE);

                // block-to-e-tile map
                const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};

                grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

                if(group_id * grid_size_grp_ != grid_size_)
                {
                    throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
                }

                grid_size_ += grid_size_grp_;

                // check block-to-E-tile
                if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
                {
                    throw std::runtime_error("wrong! block_2_etile_map validation failed");
                }

                if(!GridwiseGemm::
                       template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
                           AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
                {
                    throw std::runtime_error(
                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
                }

                gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
                    nullptr,
                    nullptr,
                    p_ds_grid,
                    nullptr,
                    AverM,
                    N,
                    K,
                    StrideA,
                    StrideB,
                    StrideDs,
                    StrideE,
                });

                group_id++;
            }

            const auto e_grid_desc_sum_m_n =
                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
                    sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);

            const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};

            barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
        }

        //  private:
        index_t group_count_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation c_element_op_;

        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
        std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
        std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;

        const void* grouped_gemm_kernel_args_dev;

        index_t grid_size_;
        index_t grid_size_grp_;
        index_t barrier_size_grp_;

        index_t sum_of_m;
        index_t k_batch_;
    };
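Because every group must share the same M, N and K, the per-group grid size is computed once from an "average M" descriptor and then multiplied by the group count. A quick arithmetic sketch with made-up tile sizes:

#include <cstdio>

int main()
{
    // Hypothetical configuration, not taken from this commit.
    const int MPerBlock = 256, NPerBlock = 128;
    const int sum_of_m = 4096, group_count = 8, N = 768, k_batch = 2;

    const int AverM = (sum_of_m + group_count - 1) / group_count; // 512
    const int M0    = (AverM + MPerBlock - 1) / MPerBlock;        // 2
    const int N0    = (N + NPerBlock - 1) / NPerBlock;            // 6

    const int grid_size_grp = M0 * N0 * k_batch;        // 24 blocks per group
    const int grid_size     = grid_size_grp * group_count; // 192 blocks total

    printf("grid_size_grp = %d, grid_size = %d\n", grid_size_grp, grid_size);
    return 0;
}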
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            bool has_main_k_block_loop = true;

            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
            {
                const auto KPad =
                    GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_);

                if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
                {
                    throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
                }
            }

            if(arg.grouped_gemm_kernel_args_dev == nullptr)
            {
                throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
            }

            float ave_time = 0;

            auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
                const auto kernel =
                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                     GroupedGemmKernelArgument<NumDTensor>,
                                                     GemmSpec,
                                                     ALayout,
                                                     BLayout,
                                                     DsLayout,
                                                     ELayout,
                                                     DsDataType,
                                                     Block2ETileMap,
                                                     GroupedGemmBlock2ETileMap,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CDEElementwiseOperation,
                                                     e_global_memory_operation_,
                                                     has_main_k_block_loop_>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
                    reinterpret_cast<uint32_t*>(arg.p_workspace_),
                    arg.barrier_size_grp_,
                    arg.gemm_desc_kernel_arg_.size(),
                    arg.grid_size_grp_,
                    arg.k_batch_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_);
            };

            constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
            constexpr auto Set       = InMemoryDataOperationEnum::Set;

            if(arg.k_batch_ > 1)
            {
                if(has_main_k_block_loop)
                {
                    ave_time =
                        launch_kernel(integral_constant<bool, true>{},
                                      integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
                }
                else
                {
                    ave_time =
                        launch_kernel(integral_constant<bool, false>{},
                                      integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
                }
            }
            else
            {
                if(has_main_k_block_loop)
                {
                    ave_time = launch_kernel(integral_constant<bool, true>{},
                                             integral_constant<InMemoryDataOperationEnum, Set>{});
                }
                else
                {
                    ave_time = launch_kernel(integral_constant<bool, false>{},
                                             integral_constant<InMemoryDataOperationEnum, Set>{});
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
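The invoker turns two runtime decisions (whether the main K loop exists, and Set vs. AtomicAdd stores) into compile-time template arguments by passing integral_constant tags into a generic launch lambda. A compact standalone sketch of that dispatch pattern with a toy "kernel" (illustrative only):

#include <cstdio>
#include <type_traits>

enum class MemOp { Set, AtomicAdd };

// Toy "kernel" whose behavior is fixed entirely at compile time.
template <bool HasMainLoop, MemOp Op>
void run_kernel()
{
    printf("kernel<HasMainLoop=%d, Op=%s>\n", HasMainLoop, Op == MemOp::Set ? "Set" : "AtomicAdd");
}

int main()
{
    const bool has_main_loop = true; // runtime values
    const int  k_batch       = 2;

    auto launch = [&](auto has_main_loop_, auto op_) {
        // The parameters are integral_constant tags, so ::value is a constant expression.
        run_kernel<decltype(has_main_loop_)::value, decltype(op_)::value>();
    };

    if(k_batch > 1)
    {
        if(has_main_loop)
            launch(std::integral_constant<bool, true>{},
                   std::integral_constant<MemOp, MemOp::AtomicAdd>{});
        else
            launch(std::integral_constant<bool, false>{},
                   std::integral_constant<MemOp, MemOp::AtomicAdd>{});
    }
    else
    {
        if(has_main_loop)
            launch(std::integral_constant<bool, true>{}, std::integral_constant<MemOp, MemOp::Set>{});
        else
            launch(std::integral_constant<bool, false>{}, std::integral_constant<MemOp, MemOp::Set>{});
    }
    return 0;
}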
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
        {
            return false;
        }

        bool supported = true;

        // If we use padding we do not support vector loads for dimensions not divisible by vector
        // load size.
        if constexpr(GemmSpec != GemmSpecialization::Default)
        {
            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
            // thus we have to adapt it to the {M,K} or {N,K} layout.
            const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
            const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;

            for(index_t i = 0; i < arg.group_count_; ++i)
            {
                const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
                const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});

                supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
                supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
            }
        }

        return supported;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*>& p_As,
                             std::vector<const void*>& p_Bs,
                             std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                             std::vector<void*>& p_Es,
                             std::vector<GemmDesc> gemm_descs,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation c_element_op)
    {
        return Argument{p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_As,
                        std::vector<const void*>& p_Bs,
                        std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                        std::vector<void*>& p_Es,
                        std::vector<GemmDesc>& gemm_descs,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(
            p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedGemm_Xdl_Fixed_NK"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">";
        // clang-format on

        return str.str();
    }

    static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
    {
        arg.grouped_gemm_kernel_args_dev = kernel_args;
    }

    // polymorphic
    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
    {
        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
    }

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);

        return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
    }

    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);

        return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
    }

    void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
    {
        auto p_arg_ = dynamic_cast<Argument*>(p_arg);

        p_arg_->p_workspace_ = p_workspace;

        hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
    }

    static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }

    // polymorphic
    void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
    {
        return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
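GetWorkSpaceSize reserves one uint32_t barrier counter per output tile per group (SetWorkSpacePointer zeroes them before use), while GetDeviceKernelArgSize sizes the device-side argument array. A small arithmetic sketch with invented sizes:

#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical numbers, for sizing only.
    const size_t group_count      = 16;
    const size_t barrier_size_grp = 48; // output tiles used for barrier bookkeeping per group
    const size_t kernel_arg_bytes = 96; // stand-in for sizeof(GroupedGemmKernelArgument<NumDTensor>)

    const size_t workspace_bytes = group_count * barrier_size_grp * sizeof(uint32_t);
    const size_t dev_arg_bytes   = group_count * kernel_arg_bytes;

    printf("barrier workspace: %zu bytes, kernel-arg buffer: %zu bytes\n",
           workspace_bytes, dev_arg_bytes);
    return 0;
}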
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp   (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseGemm,
          typename GemmDesc,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                       const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

    __shared__ uint8_t p_shared[shared_size];

    const index_t block_id = get_block_1d_id();
    const auto gemm_desc_ptr =
        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

    index_t left     = 0;
    index_t right    = group_count;
    index_t group_id = index_t((left + right) / 2);

    while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
             block_id < gemm_desc_ptr[group_id].block_end_)) &&
          left <= right)
    {
        if(block_id < gemm_desc_ptr[group_id].block_start_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = index_t((left + right) / 2);
    }

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
        gemm_desc_ptr[group_id].karg_,
        static_cast<void*>(p_shared),
        gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
    ignore = gemm_descs_const;
    ignore = group_count;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
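Instead of dividing by a fixed per-group grid size, this kernel finds its group by binary-searching the [block_start_, block_end_) ranges stored in each group's descriptor. A standalone host-side sketch of that lookup, with invented ranges:

#include <cstdio>

struct GroupRange { int block_start_, block_end_; }; // half-open block-id range per group

// Same search shape as the kernel: narrow [left, right] until block_id falls inside a range.
int find_group(const GroupRange* groups, int group_count, int block_id)
{
    int left = 0, right = group_count;
    int group_id = (left + right) / 2;
    while(!(block_id >= groups[group_id].block_start_ && block_id < groups[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < groups[group_id].block_start_)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    const GroupRange groups[] = {{0, 12}, {12, 20}, {20, 44}}; // 3 GEMMs, 44 blocks total
    for(int block_id : {0, 11, 12, 19, 43})
        printf("block %2d -> group %d\n", block_id, find_group(groups, 3, block_id));
    return 0;
}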
template <typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
          typename ADataType, typename BDataType, typename AccDataType, typename CShuffleDataType,
          typename DsDataType, typename EDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage,
          ck::index_t BlockSize,
          ck::index_t MPerBlock, ck::index_t NPerBlock, ck::index_t KPerBlock,
          ck::index_t AK1, ck::index_t BK1,
          ck::index_t MPerXDL, ck::index_t NPerXDL,
          ck::index_t MXdlPerWave, ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          // Current implementation does not support multiple D fusions.
          enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
                          is_same_v<DsDataType, ck::Tuple<>>,
                      bool> = false>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                                                                           BLayout,
                                                                           DsLayout,
                                                                           ELayout,
                                                                           ADataType,
                                                                           BDataType,
                                                                           DsDataType,
                                                                           EDataType,
                                                                           AElementwiseOperation,
                                                                           BElementwiseOperation,
                                                                           CDEElementwiseOperation>
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static_assert(KPerBlock % AK1 == 0);
    static constexpr index_t K0PerBlock = KPerBlock / AK1;

    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
        BlockSize, ADataType, BDataType, AccDataType, EDataType,
        ALayout, BLayout, ELayout,
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
        GemmSpec, NumGemmKPrefetchStage,
        MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1,
        MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferScalarPerVector_NPerBlock,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        LoopSched, PipelineVer>;

    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
    // Block2CTileMap configuration parameter.
    static constexpr index_t B2E_M01 = 8;
    using GroupedGemmBlock2ETileMap  = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
    using KernelArgument             = typename GridwiseGemm::Argument;

    struct GemmTransKernelArg
    {
        KernelArgument karg_;
        GroupedGemmBlock2ETileMap block_2_ctile_map_;
        index_t block_start_, block_end_;

        GemmTransKernelArg() = default;

        GemmTransKernelArg(KernelArgument&& karg,
                           GroupedGemmBlock2ETileMap&& b2c_map,
                           index_t block_start,
                           index_t block_end)
            : karg_{karg},
              block_2_ctile_map_{b2c_map},
              block_start_{block_start},
              block_end_{block_end}
        {
        }
    };

    static constexpr index_t DefaultKBatch = 1;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs)
            : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
        {
            // TODO: use occupancy api to calculate appropriate batch size.
        }

        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 index_t kbatch)
            : K_BATCH{kbatch}
        {
            grid_size_ = 0;

            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
            }

            gemm_kernel_args_.reserve(group_count_);

            skipped_group_count_ = 0;

            for(std::size_t i = 0; i < gemm_descs.size(); ++i)
            {
                const index_t M = gemm_descs[i].M_;
                const index_t N = gemm_descs[i].N_;
                const index_t K = gemm_descs[i].K_;

                if(M == 0)
                {
                    skipped_group_count_++;
                    continue;
                }

                const index_t stride_a = gemm_descs[i].stride_A_;
                const index_t stride_b = gemm_descs[i].stride_B_;
                const index_t stride_c = gemm_descs[i].stride_C_;

                const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
                const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
                const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
                const index_t k0       = GridwiseGemm::CalculateK0(K, K_BATCH);

                const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);

                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
                                           type_convert<const BDataType*>(p_Bs[i]),
                                           type_convert<EDataType*>(p_Es[i]),
                                           M,
                                           N,
                                           K,
                                           stride_a,
                                           stride_b,
                                           stride_c,
                                           m_padded,
                                           n_padded,
                                           k_padded,
                                           k0,
                                           K_BATCH};

                gemm_kernel_args_.emplace_back(
                    std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
            }
        }

        /**
         * @brief      Recalculate group grid size for all gemms and update B2C maps.
         *
         * @param[in]  kbatch  The new splitK parameter value.
         */
        void UpdateKBatch(index_t kbatch)
        {
            K_BATCH    = kbatch;
            grid_size_ = 0;

            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
            {
                auto& karg = gemm_kernel_args_[i].karg_;

                const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
                const index_t k0       = GridwiseGemm::CalculateK0(karg.K, K_BATCH);

                const auto c_grid_desc_m_n =
                    GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                karg.KPadded = k_padded;
                karg.K0      = k0;
                karg.k_batch = K_BATCH;

                gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
                gemm_kernel_args_[i].block_start_       = block_start;
                gemm_kernel_args_[i].block_end_         = block_end;
            }
        }

        //  private:
        index_t K_BATCH;
        index_t group_count_;
        index_t skipped_group_count_;

        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t grid_size_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;

            bool all_have_kbatch_gt_one      = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
            {
                const auto& karg = arg.gemm_kernel_args_[i].karg_;
                if(stream_config.log_level_ > 0)
                {
                    karg.Print();
                }

                auto kbatch = karg.k_batch;

                if(!GridwiseGemm::CheckValidity(karg))
                {
                    std::ostringstream err;
                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                K0 = karg.K0;

                bool not_all_have_main_k0_block_loop_same =
                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);

                if(not_all_have_main_k0_block_loop_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                if(not_all_have_kbatch_value_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same kbatch value (=1 or >1)! "
                        << "group [" << i << "], kbatch: " << kbatch << ", group [0], kbatch: "
                        << arg.gemm_kernel_args_[0].karg_.k_batch << " in " << __FILE__ << ":"
                        << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }
            }

            hip_check_error(
                hipMemcpyWithStream(arg.p_workspace_,
                                    arg.gemm_kernel_args_.data(),
                                    arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));

            float ave_time = 0;

            const auto Run = [&](const auto& kernel) {
                if(all_have_kbatch_gt_one)
                {
                    for(const auto& trans_arg : arg.gemm_kernel_args_)
                    {
                        const auto& karg = trans_arg.karg_;

                        hip_check_error(hipMemsetAsync(karg.p_c_grid,
                                                       0,
                                                       karg.M * karg.N * sizeof(EDataType),
                                                       stream_config.stream_id_));
                    }
                }

                ave_time = launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.p_workspace_),
                    arg.gemm_kernel_args_.size());
            };

            if(all_have_main_k0_block_loop)
            {
                if(all_have_kbatch_gt_one)
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       true,
                                                       InMemoryDataOperationEnum::AtomicAdd>;
                    Run(kernel);
                }
                else
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       true,
                                                       InMemoryDataOperationEnum::Set>;
                    Run(kernel);
                }
            }
            else
            {
                if(all_have_kbatch_gt_one)
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       false,
                                                       InMemoryDataOperationEnum::AtomicAdd>;
                    Run(kernel);
                }
                else
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       false,
                                                       InMemoryDataOperationEnum::Set>;
                    Run(kernel);
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
{
#if DEBUG_LOG
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
"and kernel args size!"
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
bool
supported
=
true
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_kernel_args_
.
size
();
++
i
)
{
const
auto
&
a
=
arg
.
gemm_kernel_args_
[
i
].
karg_
;
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
if
(
not
group_arg_valid
)
{
#if DEBUG_LOG
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
a
.
Print
();
#endif // DEBUG_LOG
}
supported
=
supported
&&
group_arg_valid
;
}
return
supported
;
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>
gemm_descs
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
)
{
return
Argument
{
p_As
,
p_Bs
,
p_Es
,
gemm_descs
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
p_Bs
,
p_Es
,
gemm_descs
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedGemm_XdlSplitK"
<<
"<"
<<
std
::
string
(
ALayout
::
name
)[
0
]
<<
","
<<
std
::
string
(
BLayout
::
name
)[
0
]
<<
","
<<
std
::
string
(
ELayout
::
name
)[
0
]
<<
","
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
;
// clang-format on
return
str
.
str
();
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
gemm_kernel_args_
.
size
()
*
sizeof
(
GemmTransKernelArg
);
}
static
void
SetKBatchSize
(
Argument
&
arg
,
index_t
kbatch
)
{
arg
.
UpdateKBatch
(
kbatch
);
}
// polymorphic
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
{
return
SetKBatchSize
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kbatch
);
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
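The split-K grouped GEMM above is driven entirely through the polymorphic interface visible in this listing (MakeArgumentPointer, GetWorkSpaceSize, SetKBatchSize, MakeInvokerPointer, IsSupportedArgument). The sketch below shows one plausible host-side call sequence; DeviceOp stands for some concrete DeviceGroupedGemm_XdlSplitK instantiation with NumDTensor == 0, the device pointers and GemmDesc contents are assumed to be prepared by the caller, and the StreamConfig construction is only an assumption, not taken from this commit.

// Sketch only: illustrative host-side driver, not part of the repository.
std::vector<const void*> p_As = {a0_dev, a1_dev};
std::vector<const void*> p_Bs = {b0_dev, b1_dev};
std::vector<std::array<const void*, 0>> p_Ds{}; // NumDTensor assumed to be 0
std::vector<void*> p_Es       = {e0_dev, e1_dev};
std::vector<GemmDesc> descs   = {/* M, N, K and strides for each group */};

DeviceOp op;
auto arg     = op.MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Es, descs, {}, {}, {});
auto invoker = op.MakeInvokerPointer();

op.SetKBatchSize(arg.get(), 4);    // split the K loop of every group into 4 partial sums
arg->p_workspace_ = workspace_dev; // device buffer of GetWorkSpaceSize(arg.get()) bytes

if(op.IsSupportedArgument(arg.get()))
{
    float ms = invoker->Run(arg.get(), StreamConfig{}); // returns averaged kernel time
}

With k_batch > 1 the invoker first zeroes every group's E tensor and then launches the AtomicAdd variant of the kernel, so the partial K sums accumulate correctly; with k_batch == 1 the Set variant is used and no pre-clear of E is needed.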
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Image to column:
//   input  : input image [G, N, Di, Hi, Wi, C]
//   output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C]
//   input  : input image [N, Di, Hi, Wi, G, C]
//   output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C]
template <index_t NDimSpatial,
          typename ImageLayout,
          typename InputDataType,
          typename OutputDataType,
          index_t BlockSize,
          index_t MPerBlock,
          index_t KPerBlock,
          typename ThreadClusterLengths,
          index_t ScalarPerVector,
          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct DeviceImageToColumnImpl
    : public DeviceConvTensorRearrange<NDimSpatial,
                                       ImageLayout,
                                       InputDataType,
                                       OutputDataType,
                                       conv_tensor_rearrange_op::ImageToColumn>
{
    static constexpr bool is_NSpatialGC =
        std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
        std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
        std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
    static constexpr bool is_GNSpatialC =
        std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
        std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
        std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto conv_to_gemm_transformer =
        TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
            MPerBlock, 0 /* NPerBlock*/, KPerBlock};

    // Use MakeADescriptor_M_K from grouped convolution forward
    static auto
    MakeInputDescriptor_M_K(const ck::index_t N,
                            const ck::index_t C,
                            const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                            const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                            const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                            const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
                            const std::array<index_t, NDimSpatial>& conv_filter_strides,
                            const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                            const std::array<index_t, NDimSpatial>& input_left_pads,
                            const std::array<index_t, NDimSpatial>& input_right_pads)
    {
        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{1};
        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{1};
        std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{1};

        auto copy = [](const auto& x, auto& y, index_t dst_offset) {
            std::copy(x.begin(), x.end(), y.begin() + dst_offset);
        };

        constexpr index_t spatial_offset = 3;
        copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset);
        copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset);
        copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset);

        // fill only significant values (C and N)
        a_g_n_c_wis_lengths[I1] = N;
        a_g_n_c_wis_lengths[I2] = C;
        b_g_k_c_xs_lengths[I2]  = C;
        c_g_n_k_wos_lengths[I1] = N;

        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
                a_g_n_c_wis_lengths,
                image_g_n_c_wis_strides,
                b_g_k_c_xs_lengths,
                {}, // not needed for A Descriptor
                c_g_n_k_wos_lengths,
                {}, // not needed for A Descriptor
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
                input_right_pads);

        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
        return in_gemmm_gemmk_desc;
    }

    static auto
    MakeOutDescriptor_M_K(const ck::index_t N,
                          const ck::index_t C,
                          const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                          const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                          const std::array<index_t, 3>& gemm_g_m_k_strides)
    {
        const index_t NDoHoWo =
            N * ck::accumulate_n<index_t>(
                    output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
        const index_t CZYX =
            C * ck::accumulate_n<index_t>(
                    filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());

        const auto desc_mraw_kraw = make_naive_tensor_descriptor(
            make_tuple(NDoHoWo, CZYX),
            make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2]));
        return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
    }

    using InputGridDesc =
        remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(1, 1, {}, {}, {}))>;

    using Block2ETileMap = remove_cvref_t<
        decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
            OutputGridDesc{}))>;

    using GridwiseTensorRearrangeKernel =
        GridwiseTensorRearrange<InputGridDesc,
                                InputDataType,
                                OutputGridDesc,
                                OutputDataType,
                                BlockSize,
                                MPerBlock,
                                KPerBlock,
                                ThreadClusterLengths,
                                ScalarPerVector,
                                InMemoryDataOperationEnum::Set,
                                Block2ETileMap,
                                ComputePtrOffsetOfStridedBatch<I0>>;

    struct Argument : public BaseArgument
    {
        Argument(const void* p_in, // input image
                 void* p_out,      // gemm form
                 const ck::index_t G,
                 const ck::index_t N,
                 const ck::index_t C,
                 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                 const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
                 const std::array<index_t, 3>& gemm_g_m_k_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                 const std::array<index_t, NDimSpatial>& input_left_pads,
                 const std::array<index_t, NDimSpatial>& input_right_pads)
            : G_(G),
              C_(C),
              X_(filter_spatial_lengths[NDimSpatial - I1]),
              p_in_{static_cast<const InputDataType*>(p_in)},
              p_out_{static_cast<OutputDataType*>(p_out)},
              image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
              conv_filter_strides_{conv_filter_strides},
              conv_filter_dilations_{conv_filter_dilations},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
            in_grid_desc_m_k_  = MakeInputDescriptor_M_K(N,
                                                         C,
                                                         input_spatial_lengths,
                                                         filter_spatial_lengths,
                                                         output_spatial_lengths,
                                                         image_g_n_c_wis_strides,
                                                         conv_filter_strides,
                                                         conv_filter_dilations,
                                                         input_left_pads,
                                                         input_right_pads);
            out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
                N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides);

            compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0];
            compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0];
        }

        void Print() const
        {
            std::cout << in_grid_desc_m_k_ << std::endl;
            std::cout << out_grid_desc_m_k_ << std::endl;
        }

        const ck::index_t G_;
        const ck::index_t C_;
        const ck::index_t X_;

        const InputDataType* p_in_;
        OutputDataType* p_out_;

        const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
        const std::array<index_t, NDimSpatial>& conv_filter_strides_;
        const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
        const std::array<index_t, NDimSpatial>& input_left_pads_;
        const std::array<index_t, NDimSpatial>& input_right_pads_;

        InputGridDesc in_grid_desc_m_k_;
        OutputGridDesc out_grid_desc_m_k_;
        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            const auto block_2_tile_map =
                BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
                    arg.out_grid_desc_m_k_);
            const index_t grid_size =
                block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_;

            const auto kernel = kernel_tensor_rearrange<InputGridDesc,
                                                        InputDataType,
                                                        OutputGridDesc,
                                                        OutputDataType,
                                                        Block2ETileMap,
                                                        ComputePtrOffsetOfStridedBatch<I0>,
                                                        GridwiseTensorRearrangeKernel>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(grid_size),
                                                        dim3(BlockSize),
                                                        0,
                                                        arg.in_grid_desc_m_k_,
                                                        arg.p_in_,
                                                        arg.out_grid_desc_m_k_,
                                                        arg.p_out_,
                                                        arg.G_,
                                                        block_2_tile_map,
                                                        arg.compute_ptr_offset_of_batch_);
            return elapsed_time;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const Argument& arg)
    {
        if constexpr(!(is_NSpatialGC || is_GNSpatialC))
        {
            return false;
        }

        const auto w_pad_left  = arg.input_left_pads_[NDimSpatial - I1];
        const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
        const auto dilation_x  = arg.conv_filter_dilations_[NDimSpatial - I1];
        const auto stride_x    = arg.conv_filter_strides_[NDimSpatial - I1];

        bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
        bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;

        // check vector acces with c not packed
        if(!is_c_packed && ScalarPerVector != 1)
            return false;
        // check vector access of filter window row (only C if C is not packed)
        if(!is_w_packed && arg.C_ % ScalarPerVector != 0)
            return false;
        // check vector access of filter window row (X * C)
        if(arg.X_ * arg.C_ % ScalarPerVector != 0)
            return false;
        // check vector access of pads (w_pad_left/w_pad_right * C)
        if(w_pad_left * arg.C_ % ScalarPerVector != 0 ||
           w_pad_right * arg.C_ % ScalarPerVector != 0)
            return false;
        // check vector access of with stride and pad
        if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 &&
           arg.C_ % ScalarPerVector != 0)
            return false;
        // check vector access of with dilation
        if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
            return false;

        return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_,
                                                            arg.out_grid_desc_m_k_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_in, // input image
                             void* p_out,      // gemm form
                             const ck::index_t G,
                             const ck::index_t N,
                             const ck::index_t C,
                             const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                             const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                             const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                             const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
                             const std::array<index_t, 3>& gemm_g_m_k_strides,
                             const std::array<index_t, NDimSpatial>& conv_filter_strides,
                             const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                             const std::array<index_t, NDimSpatial>& input_left_pads,
                             const std::array<index_t, NDimSpatial>& input_right_pads)
    {
        return Argument{static_cast<const InputDataType*>(p_in),
                        static_cast<OutputDataType*>(p_out),
                        G,
                        N,
                        C,
                        input_spatial_lengths,
                        filter_spatial_lengths,
                        output_spatial_lengths,
                        image_g_n_c_wis_strides,
                        gemm_g_m_k_strides,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in, // input image
                        void* p_out,      // gemm form
                        const ck::index_t G,
                        const ck::index_t N,
                        const ck::index_t C,
                        const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                        const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                        const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                        const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
                        const std::array<index_t, 3>& gemm_g_m_k_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads) override
    {
        return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
                                          static_cast<OutputDataType*>(p_out),
                                          G,
                                          N,
                                          C,
                                          input_spatial_lengths,
                                          filter_spatial_lengths,
                                          output_spatial_lengths,
                                          image_g_n_c_wis_strides,
                                          gemm_g_m_k_strides,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceImageToColumn"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << KPerBlock << ", "
            << ScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
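DeviceImageToColumnImpl rearranges an image tensor into the im2col ("gemm form") matrix used by implicit-GEMM convolution, one group per batch offset. A minimal sketch of how it could be instantiated and called for a 2-D case is shown below; the tile configuration and data types are illustrative guesses rather than one of the library's tuned instances, and the pointers, extents and stride arrays are assumed to be prepared by the caller.

// Sketch only: illustrative 2-D instantiation, not part of the repository.
using ImgToCol = ck::tensor_operation::device::DeviceImageToColumnImpl<
    2,                                     // NDimSpatial
    ck::tensor_layout::convolution::NHWGC, // ImageLayout
    float,                                 // InputDataType
    float,                                 // OutputDataType
    256, 128, 128,                         // BlockSize, MPerBlock, KPerBlock (assumed)
    ck::Sequence<16, 16>,                  // ThreadClusterLengths (assumed)
    1>;                                    // ScalarPerVector

ImgToCol op;
auto arg = op.MakeArgument(p_image_dev, p_gemm_dev,
                           G, N, C,
                           {Hi, Wi}, {Y, X}, {Ho, Wo},
                           image_g_n_c_wis_strides, // strides of [G, N, C, Hi, Wi]
                           gemm_g_m_k_strides,      // strides of [G, M, K]
                           {stride_h, stride_w}, {dilation_h, dilation_w},
                           {pad_h, pad_w}, {pad_h, pad_w});

auto invoker = op.MakeInvoker();
if(op.IsSupportedArgument(arg))
{
    invoker.Run(arg, StreamConfig{});
}

The M dimension of the output is N * Ho * Wo and the K dimension is Y * X * C, with MPerBlock/KPerBlock padding applied by the MatrixPadder so the grid always covers whole tiles.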
include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// output[indices] = input
template <typename DOutDataType,
          typename IndexDataType,
          typename DInDataType,
          ck::index_t InOutVectorSize>
struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>
{
    using DInDataType_AutomicAddPreCast =
        conditional_t<is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
                      DInDataType,
                      float>;

    using PassThrough  = ck::tensor_operation::element_wise::PassThrough;
    using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert;

    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step)
    {
        const auto m   = desc_m.GetLength(I0);
        const auto pad = math::integer_least_multiple(m, loop_step) - m;

        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(m, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(index_t length, index_t loop_step)
    {
        const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
        return PadDescriptor_M_1d(desc_m, loop_step);
    }

    using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));

    using GridwisePutElementSet = GridwisePutElement_1D<InOutGrid1dDesc,
                                                        DOutDataType,
                                                        IndexDataType,
                                                        DInDataType,
                                                        PassThrough,
                                                        InMemoryDataOperationEnum::Set,
                                                        InOutVectorSize>;

    using GridwisePutElementAtomicAdd = GridwisePutElement_1D<InOutGrid1dDesc,
                                                              DOutDataType,
                                                              IndexDataType,
                                                              DInDataType_AutomicAddPreCast,
                                                              PassThrough,
                                                              InMemoryDataOperationEnum::AtomicAdd,
                                                              InOutVectorSize>;

    using GridwiseCasting = GridwiseElementwise_1D<Tuple<InOutGrid1dDesc>,
                                                   Tuple<InOutGrid1dDesc>,
                                                   Tuple<const DInDataType_AutomicAddPreCast*>,
                                                   Tuple<DInDataType*>,
                                                   UnaryConvert,
                                                   InOutVectorSize,
                                                   Sequence<InOutVectorSize>,
                                                   Sequence<InOutVectorSize>>;

    struct Argument : public BaseArgument
    {
        Argument(const DOutDataType* p_dout,
                 const IndexDataType* p_indices,
                 DInDataType* p_din,
                 index_t dout_length,
                 index_t din_length,
                 const std::vector<ck::index_t>& window_lengths,
                 const std::vector<ck::index_t>& window_strides,
                 const std::vector<ck::index_t>& window_dilations)
            : p_dout_{p_dout},
              p_indices_{p_indices},
              p_din_{p_din},
              dout_length_raw_{dout_length},
              din_length_raw_{din_length},
              blockSize_{256},
              windowOverlap_{false}
        {
            for(size_t i = 0; i < window_lengths.size(); ++i)
            {
                auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1;
                windowOverlap_ |= eff > window_strides.at(i);
            }
        }

        const DOutDataType* p_dout_;
        const IndexDataType* p_indices_;
        DInDataType* p_din_;
        index_t dout_length_raw_;
        index_t din_length_raw_;
        index_t blockSize_;
        bool windowOverlap_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            index_t gridSize  = getAvailableComputeUnitCount(stream_config);
            index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize;

            InOutGrid1dDesc din_grid_desc  = MakeDescriptor_M(arg.din_length_raw_, loop_step);
            InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step);

            if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
            {
                hip_check_error(hipMemsetAsync(arg.p_din_,
                                               0,
                                               arg.din_length_raw_ * sizeof(DInDataType),
                                               stream_config.stream_id_));

                if(arg.windowOverlap_)
                {
                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType,
                                                                  PassThrough>;

                    return launch_and_time_kernel(stream_config,
                                                  put_kernel,
                                                  dim3(gridSize),
                                                  dim3(arg.blockSize_),
                                                  0,
                                                  dout_grid_desc,
                                                  arg.p_dout_,
                                                  arg.p_indices_,
                                                  arg.p_din_,
                                                  PassThrough{});
                }
                else
                {
                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType,
                                                                  PassThrough>;

                    return launch_and_time_kernel(stream_config,
                                                  put_kernel,
                                                  dim3(gridSize),
                                                  dim3(arg.blockSize_),
                                                  0,
                                                  dout_grid_desc,
                                                  arg.p_dout_,
                                                  arg.p_indices_,
                                                  arg.p_din_,
                                                  PassThrough{});
                }
            }
            else
            {
                if(arg.windowOverlap_)
                {
                    if(arg.p_workspace_ == nullptr)
                        throw std::runtime_error("wrong! WorkSpace pointer has not been set");

                    hip_check_error(
                        hipMemsetAsync(arg.p_workspace_,
                                       0,
                                       arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast),
                                       stream_config.stream_id_));

                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType_AutomicAddPreCast,
                                                                  PassThrough>;

                    const auto cast_kernel =
                        kernel_elementwise_1d<GridwiseCasting,
                                              Tuple<InOutGrid1dDesc>,
                                              Tuple<InOutGrid1dDesc>,
                                              Tuple<const DInDataType_AutomicAddPreCast*>,
                                              Tuple<DInDataType*>,
                                              UnaryConvert>;

                    float elapsed_time = launch_and_time_kernel(
                        stream_config,
                        put_kernel,
                        dim3(gridSize),
                        dim3(arg.blockSize_),
                        0,
                        dout_grid_desc,
                        arg.p_dout_,
                        arg.p_indices_,
                        static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
                        PassThrough{});

                    elapsed_time += launch_and_time_kernel(
                        stream_config,
                        cast_kernel,
                        dim3(gridSize),
                        dim3(arg.blockSize_),
                        0,
                        ck::make_tuple(din_grid_desc),
                        ck::make_tuple(din_grid_desc),
                        static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
                        arg.p_din_,
                        UnaryConvert{});

                    return elapsed_time;
                }
                else
                {
                    hip_check_error(hipMemsetAsync(arg.p_din_,
                                                   0,
                                                   arg.din_length_raw_ * sizeof(DInDataType),
                                                   stream_config.stream_id_));

                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType,
                                                                  PassThrough>;

                    hip_check_error(hipMemsetAsync(arg.p_din_,
                                                   0,
                                                   arg.din_length_raw_ * sizeof(DInDataType),
                                                   stream_config.stream_id_));

                    return launch_and_time_kernel(stream_config,
                                                  put_kernel,
                                                  dim3(gridSize),
                                                  dim3(arg.blockSize_),
                                                  0,
                                                  dout_grid_desc,
                                                  arg.p_dout_,
                                                  arg.p_indices_,
                                                  arg.p_din_,
                                                  PassThrough{});
                }
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        bool needCast = pArg_->windowOverlap_ &&
                        !(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>);

        if(!needCast)
            return 0;
        else
            return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast);
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg->din_length_raw_ % InOutVectorSize != 0 ||
           pArg->dout_length_raw_ % InOutVectorSize != 0)
        {
            return false;
        }

        return true;
    }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_dout,
                        const void* p_indices,
                        void* p_din,
                        index_t dout_length,
                        index_t din_length,
                        std::vector<ck::index_t> window_lengths,
                        std::vector<ck::index_t> window_strides,
                        std::vector<ck::index_t> window_dilations) override
    {
        // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are
        // physical size of the packed tensor
        return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
                                          static_cast<const IndexDataType*>(p_indices),
                                          static_cast<DInDataType*>(p_din),
                                          dout_length,
                                          din_length,
                                          window_lengths,
                                          window_strides,
                                          window_dilations);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
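The max-pool backward op scatters dY into dX at the saved argmax indices. When DInDataType is float or double the scatter is done in place, using AtomicAdd when pooling windows overlap and a plain Set otherwise. For other gradient types with overlapping windows it accumulates into a float workspace and casts back in a second kernel, which is exactly the case where GetWorkSpaceSize returns a non-zero size. A minimal host-side sketch of that workspace handling, with an assumed DeviceMaxPoolBwdImpl instance `op` and pointers/lengths prepared elsewhere, might look like:

// Sketch only: illustrative workspace handling, not part of the repository.
auto arg = op.MakeArgumentPointer(p_dout_dev, p_indices_dev, p_din_dev,
                                  dout_length, din_length,
                                  {3, 3}, {2, 2}, {1, 1}); // 3x3 window, stride 2 -> overlapping

std::size_t ws_bytes = op.GetWorkSpaceSize(arg.get()); // 0 unless a cast workspace is required
void* ws_dev = nullptr;
if(ws_bytes != 0)
{
    hip_check_error(hipMalloc(&ws_dev, ws_bytes));
    arg->p_workspace_ = ws_dev; // float accumulation buffer for the AtomicAdd path
}

auto invoker = op.MakeInvokerPointer();
if(op.IsSupportedArgument(arg.get()))
    invoker->Run(arg.get(), StreamConfig{});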
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
@@ -10,8 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
@@ -20,11 +19,16 @@ namespace tensor_operation {
 namespace device {

 // Y = Normalization(X, Beta, Gamma)
+// M: Invarient length
+// K: Reduce length (Calculate mean and variance along K dimension)
+// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
+// Then, M = N, K = C * H * W
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,
...
@@ -40,12 +44,13 @@ template <typename XDataType,
           index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize,
           bool UseWelford = true>
 struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
                                                             ComputeDataType,
                                                             YDataType,
+                                                            SaveMeanInvStdDataType,
                                                             YElementwiseOperation,
                                                             Rank,
                                                             NumReduceDim>
...
@@ -61,19 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                    (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
                   "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
+
     using PassThrough = tensor_operation::element_wise::PassThrough;

+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+    static_assert(!reduceAllDim); // TODO
+
     static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                     const std::vector<index_t>& inStrides,
-                                    int blkGroupSize,
                                     int numBlockTileIteration)
     {
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
         static constexpr index_t numSrcDim = Rank;
-        static constexpr bool reduceAllDim = (NumInvariantDim == 0);

         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...
@@ -117,10 +127,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
         const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

-        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
         const auto inPad_M =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
+        const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;

         auto in_grid_desc_m_k_padded =
             transform_tensor_descriptor(in_grid_desc_m_k,
...
@@ -132,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         return (in_grid_desc_m_k_padded);
     };

-    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
+                                               const std::vector<index_t>& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded =
+            transform_tensor_descriptor(grid_desc_m,
+                                        make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return grid_desc_m_padded;
+    }
+
+    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
+    using GridDesc_M   = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));

     struct Argument : public BaseArgument
     {
...
@@ -141,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                  const std::vector<index_t> gammaStrides,
                  const std::vector<index_t> betaStrides,
                  const std::vector<index_t> yStrides,
+                 const std::vector<index_t> saveMeanStrides,
+                 const std::vector<index_t> saveInvStdStrides,
                  const std::vector<index_t> reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
+              p_saveMean_(p_saveMean),
+              p_saveInvStd_(p_saveInvStd),
               y_elementwise_op_(y_elementwise_op)
         {
             epsilon_ = static_cast<ComputeDataType>(epsilon);
...
@@ -161,30 +206,31 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
             betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
+            saveMeanStrides_   = saveMeanStrides;
+            saveInvStdStrides_ = saveInvStdStrides;

-            long_index_t invariant_total_length;
-            long_index_t reduce_total_length;
+            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);

-            std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, NumReduceDim>(Lengths_);
+            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);

-            blkGroupSize_          = 1;
-            numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+            gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);

-            gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
-                        M_BlockTileSize * blkGroupSize_;

-            x_grid_desc_m_k_ =
-                MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
+            x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
             gamma_grid_desc_m_k_ =
-                MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
+                MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_);
             beta_grid_desc_m_k_ =
-                MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
-            y_grid_desc_m_k_ =
-                MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
+                MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
+            y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
+
+            save_mean_grid_desc_m_    = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
+            save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
+
+            isSweeponce_ =
+                x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
         }

         ComputeDataType epsilon_;
...
@@ -193,16 +239,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
         YDataType* p_y_;
+        SaveMeanInvStdDataType* p_saveMean_;
+        SaveMeanInvStdDataType* p_saveInvStd_;

         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
+        std::vector<index_t> saveMeanStrides_;
+        std::vector<index_t> saveInvStdStrides_;

         YElementwiseOperation y_elementwise_op_;

-        int blkGroupSize_;
         int numBlockTileIteration_;
         size_t gridSize_;
...
@@ -210,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         GridDesc_M_K gamma_grid_desc_m_k_;
         GridDesc_M_K beta_grid_desc_m_k_;
         GridDesc_M_K y_grid_desc_m_k_;
+        GridDesc_M save_mean_grid_desc_m_;
+        GridDesc_M save_inv_std_grid_desc_m_;
+
+        bool isSweeponce_;
+
+        index_t MRaw_; // invarient length
+        index_t KRaw_; // reduce length
+        index_t invariant_lowest_length_;
     };

     struct Invoker : public BaseInvoker
...
@@ -221,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
+                                                     SaveMeanInvStdDataType,
                                                      ComputeDataType,
                                                      YElementwiseOperation,
                                                      GridDesc_M_K,
+                                                     GridDesc_M,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
...
@@ -237,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                      BetaSrcVectorSize,
                                                      XYSrcVectorDim,
                                                      YDstVectorSize,
+                                                     SaveMeanInvStdDstVectorSize,
                                                      UseWelford>(arg.isSweeponce_);

             float avg_time = 0;
...
@@ -249,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                  arg.gamma_grid_desc_m_k_,
                                                  arg.beta_grid_desc_m_k_,
                                                  arg.y_grid_desc_m_k_,
+                                                 arg.save_mean_grid_desc_m_,
+                                                 arg.save_inv_std_grid_desc_m_,
                                                  arg.numBlockTileIteration_,
                                                  arg.epsilon_,
                                                  arg.p_x_,
                                                  arg.p_gamma_,
                                                  arg.p_beta_,
                                                  arg.p_y_,
+                                                 arg.p_saveMean_,
+                                                 arg.p_saveInvStd_,
                                                  arg.y_elementwise_op_);

             return (avg_time);
...
@@ -271,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         if constexpr(XYSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)
...
@@ -281,10 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             }
             else
             {
+                printf("!!!! %d \n", p_arg_->invariant_lowest_length_);
                 if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                     return false;

-                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                     return false;
+
+                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
+                    return false;
             };
         }
...
@@ -295,12 +361,12 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
                 return false;
-        };
-
-        if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
-        {
-            return false;
-        }
+
+            if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
+            {
+                return false;
+            }
+        };

         // if fastest dim is not reduced
         if constexpr(GammaSrcVectorDim == 0)
...
@@ -326,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                 return (false);

-            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+            if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                 return (false);
         }
         else // if fastest dim is reduced
...
@@ -338,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 return (false);
         }

+        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
+            return false;
+
         return true;
     };
...
@@ -347,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                         const std::vector<index_t> gammaStrides,
                         const std::vector<index_t> betaStrides,
                         const std::vector<index_t> yStrides,
+                        const std::vector<index_t> saveMeanStrides,
+                        const std::vector<index_t> saveInvStdStrides,
                         const std::vector<index_t> reduceDims,
                         double epsilon,
                         const void* p_x,
...
@@ -354,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                         const void* p_beta,
                         void* p_y,
                         void* p_saveMean,
-                        void* p_saveInvVar,
+                        void* p_saveInvStd,
                         YElementwiseOperation y_elementwise_op) override
     {
-        // TODO
-        // Optional cache of the intermediate results (mean and InvVariance) during the
-        // forward pass could speedup in the backward
-        ignore = p_saveMean;
-        ignore = p_saveInvVar;
-
         if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
            betaStrides.size() != Rank || yStrides.size() != Rank ||
            saveMeanStrides.size() != NumInvariantDim ||
           saveInvStdStrides.size() != NumInvariantDim)
             throw std::runtime_error("dimension is incorrect");

         return std::make_unique<Argument>(lengths,
                                           xStrides,
                                           gammaStrides,
                                           betaStrides,
                                           yStrides,
+                                          saveMeanStrides,
+                                          saveInvStdStrides,
                                           reduceDims,
                                           y_elementwise_op,
                                           epsilon,
                                           static_cast<const XDataType*>(p_x),
                                           static_cast<const GammaDataType*>(p_gamma),
                                           static_cast<const BetaDataType*>(p_beta),
-                                          static_cast<YDataType*>(p_y));
+                                          static_cast<YDataType*>(p_y),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveMean),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
     };

     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...
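This change extends DeviceNormalizationImpl so that, besides Y, it also writes out the per-row mean and inverse standard deviation for later reuse (e.g. in a backward pass). The scalar reference below is a sketch of the quantity computed for a single invariant row of length K; it follows the listing's own description (mean and variance along K, Y = Normalization(X, Beta, Gamma)) and is not code from the repository. It assumes <cmath> is available.

// Sketch only: scalar reference for one invariant row, matching the saved
// mean / inverse-standard-deviation outputs added in this commit.
void reference_normalization_row(const float* x, const float* gamma, const float* beta,
                                 float* y, float& save_mean, float& save_inv_std,
                                 int K, float epsilon)
{
    float mean = 0.f;
    for(int k = 0; k < K; ++k)
        mean += x[k];
    mean /= K;

    float var = 0.f;
    for(int k = 0; k < K; ++k)
        var += (x[k] - mean) * (x[k] - mean);
    var /= K;

    float inv_std = 1.f / std::sqrt(var + epsilon);
    for(int k = 0; k < K; ++k)
        y[k] = gamma[k] * (x[k] - mean) * inv_std + beta[k];

    save_mean    = mean;
    save_inv_std = inv_std;
}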
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseWelford
,
typename
XDataType
,
typename
WorkspaceMeanVarDataType
,
typename
ComputeDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarGridDesc_M_KBlock
>
__global__
void
kernel_normalizationSplitK1st
(
const
XGridDesc_M_K
x_grid_desc_m_k
,
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x_global
,
WorkspaceMeanVarDataType
*
const
__restrict__
p_welford_mean
,
WorkspaceMeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
)
{
GridwiseWelford
::
Run
(
x_grid_desc_m_k
,
mean_var_grid_desc_m_kblock
,
num_k_block_tile_iteration
,
p_x_global
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
);
};
template
<
typename
GridwiseWelfordNormalization
,
typename
WorkspaceMeanVarDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
,
typename
SaveMeanInvStdGridDesc_M
>
__global__
void
kernel_normalizationSplitK2nd
(
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
const
CountGridDesc_M_KBlock
count_grid_desc_m_kblock
,
const
XYGammaBetaGridDesc_M_K
x_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
y_grid_desc_m_k
,
const
SaveMeanInvStdGridDesc_M
save_mean_grid_desc_m
,
const
SaveMeanInvStdGridDesc_M
save_inv_std_grid_desc_m
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
ComputeDataType
epsilon
,
const
WorkspaceMeanVarDataType
*
const
p_mean_global
,
const
WorkspaceMeanVarDataType
*
const
p_variance_global
,
const
int32_t
*
const
p_welford_count_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
GridwiseWelfordNormalization
::
Run
(
mean_var_grid_desc_m_kblock
,
count_grid_desc_m_kblock
,
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
save_mean_grid_desc_m
,
save_inv_std_grid_desc_m
,
num_k_mean_var_count_iteration
,
num_k_block_tile_iteration
,
k_grid_size
,
epsilon
,
p_mean_global
,
p_variance_global
,
p_welford_count_global
,
p_x_global
,
p_gamma_global
,
p_beta_global
,
p_y_global
,
p_save_mean_global
,
p_save_inv_std_global
,
y_elementwise_op
);
};
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XYVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
>
struct
DeviceNormalizationSplitKImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
{
using
WorkspaceMeanVarDataType
=
SaveMeanInvStdDataType
;
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
%
GammaSrcVectorSize
==
0
)
||
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
%
GammaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"
);
static_assert
(
((
BetaSrcVectorDim
==
0
&&
MThreadSliceSize
%
BetaSrcVectorSize
==
0
)
||
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static_assert
(
!
reduceAllDim
);
// TODO
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
kBlockSize
,
int
numBlockTileIteration
)
{
static
constexpr
index_t
numSrcDim
=
Rank
;
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
const
auto
inDesc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
in_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
reduceAllDim
)
{
const
auto
one_dim_inDesc
=
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
numSrcDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
transform_tensor_descriptor
(
one_dim_inDesc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
1
,
one_dim_inDesc
.
GetLength
(
Number
<
0
>
{})))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
}
else
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
ReduceDims
{});
const
auto
invariantDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
InvariantDims
{});
return
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}();
const
auto
invariantLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
reduceSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
inPad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
inPad_K
=
reduceSizePerBlock
*
kBlockSize
-
reduceLength
;
auto
in_grid_desc_m_k_padded
=
transform_tensor_descriptor
(
in_grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
inPad_M
),
make_right_pad_transform
(
reduceLength
,
inPad_K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
in_grid_desc_m_k_padded
);
};
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeWorkspaceMeanVarDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
}
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeWorkspaceCountDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I0
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
}
static
auto
MakeSaveMeanInvStdDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
const
auto
tupleSrcLengths
=
make_tuple_from_array_and_index_seq
(
lengths
,
InvariantDims
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array_and_index_seq
(
strides
,
InvariantDims
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
grid_desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
InvariantDims
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
pad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
pad_M
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
grid_desc_m_padded
;
}
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
Kernel1MeanVarGridDesc_M_KBlock
=
decltype
(
MakeWorkspaceMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2MeanVarGridDesc_M_KBlock
=
decltype
(
MakeWorkspaceMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2CountGridDesc_M_KBlock
=
decltype
(
MakeWorkspaceCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
SaveMeanInvStdGridDesc_M
=
decltype
(
MakeSaveMeanInvStdDescriptor_M
({
1
},
{
1
}));
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
ComputeDataType
,
WorkspaceMeanVarDataType
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordNormalization
=
GridwiseNormalizationSplitK2nd
<
WorkspaceMeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
,
SaveMeanInvStdGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYVectorDim
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
>
;
    struct Argument : public BaseArgument
    {
        Argument(const std::vector<index_t> lengths,
                 const std::vector<index_t> xStrides,
                 const std::vector<index_t> gammaStrides,
                 const std::vector<index_t> betaStrides,
                 const std::vector<index_t> yStrides,
                 const std::vector<index_t> saveMeanStrides,
                 const std::vector<index_t> saveInvStdStrides,
                 const std::vector<index_t> reduceDims,
                 YElementwiseOperation y_elementwise_op,
                 double epsilon,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y,
                 SaveMeanInvStdDataType* p_saveMean,
                 SaveMeanInvStdDataType* p_saveInvStd)
            : p_x_(p_x),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              p_saveMean_(p_saveMean),
              p_saveInvStd_(p_saveInvStd),
              p_workspace_mean_{nullptr},
              p_workspace_var_{nullptr},
              p_workspace_count_{nullptr},
              y_elementwise_op_(y_elementwise_op)
        {
            epsilon_ = static_cast<ComputeDataType>(epsilon);

            Lengths_      = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            xStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
            betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
            saveMeanStrides_   = saveMeanStrides;
            saveInvStdStrides_ = saveInvStdStrides;

            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);

            numBlockTileIteration_ = 1;
            while(true)
            {
                int testKGridSize =
                    math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);

                // we want the kGridSize_ be not more than 128
                if(testKGridSize <= 128)
                    break;

                ++numBlockTileIteration_;
            };

            kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
            gridSize_  = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_;

            // We do not use vector load for mean, var and count
            static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize;

            numMeanVarCountIteration_ =
                math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize);

            x_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_);
            gamma_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_);
            beta_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_);
            y_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_);

            save_mean_grid_desc_m_    = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
            save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);

            // We don't need to pad in K dimension for Welford1. Set KPerTile 1.
            kernel1_mean_var_grid_desc_m_kblock_ =
                MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(
                    MRaw_, kGridSize_);

            kernel2_mean_var_grid_desc_m_kblock_ =
                MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>,
                                                   M_BlockTileSize,
                                                   K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);

            kernel2_count_grid_desc_m_kblock_ =
                MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>,
                                                 M_BlockTileSize,
                                                 K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);

            if constexpr(NumInvariantDim == 0)
                invariant_lowest_length_ = 1;
            else
                invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
        }

        ComputeDataType epsilon_;

        const XDataType* p_x_;
        const GammaDataType* p_gamma_;
        const BetaDataType* p_beta_;
        YDataType* p_y_;
        SaveMeanInvStdDataType* p_saveMean_;
        SaveMeanInvStdDataType* p_saveInvStd_;

        void* p_workspace_mean_;
        void* p_workspace_var_;
        void* p_workspace_count_;

        std::vector<index_t> Lengths_;
        std::vector<index_t> xStrides_;
        std::vector<index_t> gammaStrides_;
        std::vector<index_t> betaStrides_;
        std::vector<index_t> yStrides_;
        std::vector<index_t> saveMeanStrides_;
        std::vector<index_t> saveInvStdStrides_;

        YElementwiseOperation y_elementwise_op_;

        int kGridSize_;
        int numMeanVarCountIteration_;
        int numBlockTileIteration_;
        size_t gridSize_;

        SrcGridDesc_M_K x_grid_desc_m_k_;
        SrcGridDesc_M_K gamma_grid_desc_m_k_;
        SrcGridDesc_M_K beta_grid_desc_m_k_;
        SrcGridDesc_M_K y_grid_desc_m_k_;
        SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_;
        SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_;
        Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_;
        Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
        Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;
        index_t MRaw_; // invariant length
        index_t KRaw_; // reduce length

        index_t invariant_lowest_length_;
    };
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr ||
               arg.p_workspace_count_ == nullptr)
                throw std::runtime_error("wrong! WorkSpace pointer has not been set");

            auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford,
                                                         XDataType,
                                                         WorkspaceMeanVarDataType,
                                                         ComputeDataType,
                                                         SrcGridDesc_M_K,
                                                         Kernel1MeanVarGridDesc_M_KBlock>;

            auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization,
                                                         WorkspaceMeanVarDataType,
                                                         XDataType,
                                                         GammaDataType,
                                                         BetaDataType,
                                                         YDataType,
                                                         SaveMeanInvStdDataType,
                                                         ComputeDataType,
                                                         YElementwiseOperation,
                                                         Kernel2MeanVarGridDesc_M_KBlock,
                                                         Kernel2CountGridDesc_M_KBlock,
                                                         SrcGridDesc_M_K,
                                                         SaveMeanInvStdGridDesc_M>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(
                stream_config, kernel1, dim3(arg.gridSize_), dim3(BlockSize), 0,
                arg.x_grid_desc_m_k_,
                arg.kernel1_mean_var_grid_desc_m_kblock_,
                arg.numBlockTileIteration_,
                arg.p_x_,
                static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
                static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
                static_cast<int32_t*>(arg.p_workspace_count_));

            avg_time += launch_and_time_kernel(
                stream_config, kernel2, dim3(arg.gridSize_), dim3(BlockSize), 0,
                arg.kernel2_mean_var_grid_desc_m_kblock_,
                arg.kernel2_count_grid_desc_m_kblock_,
                arg.x_grid_desc_m_k_,
                arg.gamma_grid_desc_m_k_,
                arg.beta_grid_desc_m_k_,
                arg.y_grid_desc_m_k_,
                arg.save_mean_grid_desc_m_,
                arg.save_inv_std_grid_desc_m_,
                arg.numMeanVarCountIteration_,
                arg.numBlockTileIteration_,
                arg.kGridSize_,
                arg.epsilon_,
                static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
                static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
                static_cast<const int32_t*>(arg.p_workspace_count_),
                arg.p_x_,
                arg.p_gamma_,
                arg.p_beta_,
                arg.p_y_,
                arg.p_saveMean_,
                arg.p_saveInvStd_,
                arg.y_elementwise_op_);

            return avg_time;
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };
    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        size_t workspace_size = 0;

        int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;

        // workspace for welford intermediate mean
        workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;

        // workspace for welford intermediate variance
        workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;

        // workspace for welford intermediate count
        workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;

        return (workspace_size);
    };
    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;

        int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;

        // setup buffer used for intermediate welford mean
        pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);

        index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
        mean_space_sz         = math::integer_least_multiple(mean_space_sz, 64);

        // setup buffer used for intermediate welford variance
        pArg_->p_workspace_var_ =
            reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;

        index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
        variance_space_sz         = math::integer_least_multiple(variance_space_sz, 64);

        // setup buffer used for intermediate welford count
        pArg_->p_workspace_count_ =
            reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
    };
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        if constexpr(XYVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return false;
            }
            else
            {
                if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                    return false;

                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                    return false;

                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
                    return false;
            };
        }
        else
        {
            if(p_arg_->xStrides_[Rank - 1] != 1)
                return false;

            if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
                return false;

            if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
                return false;
        };

        // if fastest dim is not reduced
        if constexpr(GammaSrcVectorDim == 0)
        {
            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
                return false;
        }
        else // if fastest dim is reduced
        {
            if(p_arg_->gammaStrides_[Rank - 1] != 1)
                return false;

            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
                return false;
        }

        // if fastest dim is not reduced
        if constexpr(BetaSrcVectorDim == 0)
        {
            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                return false;
        }
        else // if fastest dim is reduced
        {
            if(p_arg_->betaStrides_[Rank - 1] != 1)
                return false;

            if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
                return false;
        }

        if(p_arg_->kGridSize_ <= 1)
            return false;

        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
            return false;

        return true;
    };
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> lengths,
                        const std::vector<index_t> xStrides,
                        const std::vector<index_t> gammaStrides,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> saveMeanStrides,
                        const std::vector<index_t> saveInvStdStrides,
                        const std::vector<index_t> reduceDims,
                        double epsilon,
                        const void* p_x,
                        const void* p_gamma,
                        const void* p_beta,
                        void* p_y,
                        void* p_saveMean,
                        void* p_saveInvStd,
                        YElementwiseOperation y_elementwise_op) override
    {
        if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
           betaStrides.size() != Rank || yStrides.size() != Rank ||
           saveMeanStrides.size() != NumInvariantDim ||
           saveInvStdStrides.size() != NumInvariantDim)
            throw std::runtime_error("dimension is incorrect");

        return std::make_unique<Argument>(lengths,
                                          xStrides,
                                          gammaStrides,
                                          betaStrides,
                                          yStrides,
                                          saveMeanStrides,
                                          saveInvStdStrides,
                                          reduceDims,
                                          y_elementwise_op,
                                          epsilon,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const GammaDataType*>(p_gamma),
                                          static_cast<const BetaDataType*>(p_beta),
                                          static_cast<YDataType*>(p_y),
                                          static_cast<SaveMeanInvStdDataType*>(p_saveMean),
                                          static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceNormalizationSplitKImpl<" << BlockSize << ",";
        str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
        str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
        str << "XYSrcVectorDim_" << XYVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize
            << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
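The splitK normalization above runs as two kernels with an intermediate Welford workspace, so the host must query, allocate and attach that workspace before invoking. Below is a minimal host-side sketch of the call sequence, assuming a concrete DeviceNormalizationSplitKImpl instantiation, device-resident buffers, and a HIP allocation helper; the shapes, epsilon value and elementwise op are illustrative placeholders, while the member functions used (MakeArgumentPointer, IsSupportedArgument, GetWorkSpaceSize, SetWorkSpacePointer, MakeInvokerPointer, Run) come from the interface in this file.

// Host-side sketch; DeviceOp and YElementwiseOp are stand-ins for a concrete instantiation.
#include <vector>
#include <hip/hip_runtime.h>

template <typename DeviceOp, typename YElementwiseOp>
float run_splitk_normalization(DeviceOp& op,
                               const void* p_x, const void* p_gamma, const void* p_beta,
                               void* p_y, void* p_save_mean, void* p_save_inv_std,
                               std::vector<ck::index_t> lengths,
                               std::vector<ck::index_t> xStrides,
                               std::vector<ck::index_t> gammaStrides,
                               std::vector<ck::index_t> betaStrides,
                               std::vector<ck::index_t> yStrides,
                               std::vector<ck::index_t> saveMeanStrides,
                               std::vector<ck::index_t> saveInvStdStrides,
                               std::vector<ck::index_t> reduceDims,
                               YElementwiseOp y_op)
{
    auto arg = op.MakeArgumentPointer(lengths, xStrides, gammaStrides, betaStrides, yStrides,
                                      saveMeanStrides, saveInvStdStrides, reduceDims,
                                      1e-5, // epsilon, assumed value
                                      p_x, p_gamma, p_beta, p_y, p_save_mean, p_save_inv_std,
                                      y_op);

    if(!op.IsSupportedArgument(arg.get()))
        return -1.f;

    // SplitK stages partial Welford mean/var/count in device memory, so the
    // workspace must be allocated and attached before Run().
    void* p_workspace = nullptr;
    (void)hipMalloc(&p_workspace, op.GetWorkSpaceSize(arg.get()));
    op.SetWorkSpacePointer(arg.get(), p_workspace);

    auto invoker   = op.MakeInvokerPointer();
    float avg_time = invoker->Run(arg.get(), StreamConfig{nullptr, true});

    (void)hipFree(p_workspace);
    return avg_time;
}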
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
namespace ck {
namespace tensor_operation {
...
...
@@ -20,305 +11,100 @@ namespace device {
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType, // enable if OutputIndex == true
          typename ComputeDataType,
          ck::ReduceTensorOp ReduceOpId,
          bool OuputIndex,
          bool OutputIndex,
          ck::index_t BlockSize,
          ck::index_t ReduceMThreadClusterSize,
          ck::index_t ReduceKThreadClusterSize,
          ck::index_t ReduceMThreadSliceSize,
          ck::index_t ReduceKThreadSliceSize,
          ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd<ReduceOpId>
struct DevicePool2dFwd_NHWC_NHWC
    : public DevicePool3dFwd_NDHWC_NDHWC<InDataType,
                                         OutDataType,
                                         IndexDataType,
                                         ComputeDataType,
                                         ReduceOpId,
                                         OutputIndex,
                                         BlockSize,
                                         ReduceMThreadClusterSize,
                                         ReduceKThreadClusterSize,
                                         ReduceMThreadSliceSize,
                                         ReduceKThreadSliceSize,
                                         InSrcOutDstVectorSize>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using IndexDataType = int32_t;

    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
    using InElementwiseOperation =
        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
    using AccElementwiseOperation =
        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;

    static constexpr index_t InSrcOutDstVectorDim = 0;
    // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
    // not reduced.

    static constexpr ck::index_t ReduceM_BlockTileSize =
        ReduceMThreadClusterSize * ReduceMThreadSliceSize;
    static constexpr ck::index_t ReduceK_BlockTileSize =
        ReduceKThreadClusterSize * ReduceKThreadSliceSize;
    static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
                                               ck::index_t C,
                                               std::array<ck::index_t, 2> input_spatial_lengths,
                                               std::array<ck::index_t, 2> window_spatial_lengths,
                                               std::array<ck::index_t, 2> output_spatial_lengths,
                                               std::array<ck::index_t, 2> window_strides,
                                               std::array<ck::index_t, 2> input_left_pads,
                                               std::array<ck::index_t, 2> input_right_pads)
    {
        const index_t Hi = input_spatial_lengths[0];
        const index_t Wi = input_spatial_lengths[1];

        const index_t Ho = output_spatial_lengths[0];
        const index_t Wo = output_spatial_lengths[1];

        const index_t Y = window_spatial_lengths[0];
        const index_t X = window_spatial_lengths[1];

        const index_t ConvStrideH = window_strides[0];
        const index_t ConvStrideW = window_strides[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        const index_t ReduceMRaw = N * Ho * Wo * C;
        const index_t ReduceMPad =
            math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;

        const index_t ReduceKRaw = Y * X;
        const index_t ReduceKPad =
            math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
        // A[ReduceM, ReduceK]
        const auto in_grid_desc_n_hi_wi_c =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
            in_grid_desc_n_hi_wi_c,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
            in_grid_desc_n_hip_wip_c,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
            in_grid_desc_n_y_ho_x_wo_c,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
                       make_merge_transform(make_tuple(Y, X))),
            make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
            in_grid_desc_reducemraw_reducekraw,
            make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
                       make_right_pad_transform(ReduceKRaw, ReduceKPad)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B[ReduceM]
        const auto out_grid_desc_reducemraw =
            make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));

        const auto out_grid_desc_reducem = transform_tensor_descriptor(
            out_grid_desc_reducemraw,
            make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{}));

        return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
    }
    using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(
        1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));

    using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
    using BGridDesc_M   = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
// TODO
    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_in_dev,
                 OutDataType* p_out_dev,
                 int* p_out_indices_dev,
                 ck::index_t N,
                 ck::index_t C,
                 std::array<ck::index_t, 2>& input_spatial_lengths,
                 std::array<ck::index_t, 2>& window_spatial_lengths,
                 std::array<ck::index_t, 2>& output_spatial_lengths,
                 std::array<ck::index_t, 2>& window_strides,
                 std::array<ck::index_t, 2>& input_left_pads,
                 std::array<ck::index_t, 2>& input_right_pads)
            : p_in_dev_{p_in_dev},
              p_out_dev_{p_out_dev},
              p_out_indices_dev_{p_out_indices_dev},
              a_grid_desc_m_k_{},
              b_grid_desc_m_{}
        {
            const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
                                                              C,
                                                              input_spatial_lengths,
                                                              window_spatial_lengths,
                                                              output_spatial_lengths,
                                                              window_strides,
                                                              input_left_pads,
                                                              input_right_pads);

            a_grid_desc_m_k_ = descs[I0];
            b_grid_desc_m_   = descs[I1];

            invariant_lowest_length_ = C;
            reduce_lowest_length_    = window_spatial_lengths[1];

            int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];

            std::tie(in_element_op_, acc_element_op_) =
                reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
                    reduceLength);
        }

        const InDataType* p_in_dev_;
        OutDataType* p_out_dev_;
        int* p_out_indices_dev_;
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_M b_grid_desc_m_;
        InElementwiseOperation in_element_op_;
        AccElementwiseOperation acc_element_op_;

        // for checking vector load/store
        ck::index_t invariant_lowest_length_;
        ck::index_t reduce_lowest_length_;
    };
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
    using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC<InDataType,
                                                     OutDataType,
                                                     AccDataType,
                                                     IndexDataType,
                                                     AGridDesc_M_K,
                                                     BGridDesc_M,
                                                     ReduceOperation,
                                                     InElementwiseOperation,
                                                     AccElementwiseOperation,
                                                     InMemoryDataOperationEnum::Set,
                                                     false, // propagate_nan
                                                     ComputeDataType,
                                                     ReduceOpId,
                                                     OutputIndex,
                                                     BlockSize,
                                                     ReduceMThreadClusterSize,
                                                     ReduceKThreadClusterSize,
                                                     ReduceMThreadSliceSize,
                                                     ReduceKThreadSliceSize,
                                                     InSrcOutDstVectorDim,
                                                     InSrcOutDstVectorSize,
                                                     InSrcOutDstVectorSize>;
            const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
                                                         OuputIndex,
                                                         false, // don't have index input
                                                         InDataType,
                                                         OutDataType,
                                                         AccDataType,
                                                         IndexDataType,
                                                         AGridDesc_M_K,
                                                         BGridDesc_M,
                                                         InElementwiseOperation,
                                                         AccElementwiseOperation>;

            ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);

            const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);

            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.a_grid_desc_m_k_,
                                          arg.b_grid_desc_m_,
                                          arg.in_element_op_,
                                          arg.acc_element_op_,
                                          float(1),
                                          arg.p_in_dev_,
                                          nullptr,
                                          float(0),
                                          arg.p_out_dev_,
                                          arg.p_out_indices_dev_);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
        {
            return (false);
        }

        return (true);
    }
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_dev,
                        void* p_out_dev,
                        void* p_out_indices_dev,
                        ck::index_t N,
                        ck::index_t C,
                        std::array<ck::index_t, 2> input_spatial_lengths,
                        std::array<ck::index_t, 2> window_spatial_lengths,
                        std::array<ck::index_t, 2> output_spatial_lengths,
                        std::array<ck::index_t, 2> window_strides,
                        std::array<ck::index_t, 2> input_left_pads,
                        std::array<ck::index_t, 2> input_right_pads) override
                        std::vector<ck::index_t> input_lengths,
                        std::vector<ck::index_t> window_lengths,
                        std::vector<ck::index_t> output_lengths,
                        std::vector<ck::index_t> input_stride,
                        std::vector<ck::index_t> output_stride,
                        std::vector<ck::index_t> indices_stride,
                        std::vector<ck::index_t> window_strides,
                        std::vector<ck::index_t> window_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        std::vector<ck::index_t> pooling_dims) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
                                          static_cast<OutDataType*>(p_out_dev),
                                          static_cast<int*>(p_out_indices_dev),
                                          N,
                                          C,
                                          input_spatial_lengths,
                                          window_spatial_lengths,
                                          output_spatial_lengths,
                                          window_strides,
                                          input_left_pads,
                                          input_right_pads);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
        str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
        str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
        str << "InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
        // clang-format on

        return str.str();
        static constexpr index_t InOutRank  = 4;
        static constexpr index_t WindowRank = 2;

        if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
           input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
           window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
           input_right_pads.size() != WindowRank)
            throw std::runtime_error("dimension is incorrect");

        if(pooling_dims != std::vector<ck::index_t>{2, 3})
            throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");

        // NCHW to NCDHW
        input_lengths.insert(input_lengths.begin() + 2, 1);
        output_lengths.insert(output_lengths.begin() + 2, 1);
        input_stride.insert(input_stride.begin() + 2, 0);
        output_stride.insert(output_stride.begin() + 2, 0);
        indices_stride.insert(indices_stride.begin() + 2, 0);

        // YX to ZYX
        window_lengths.insert(window_lengths.begin(), 1);
        window_strides.insert(window_strides.begin(), 0);
        window_dilations.insert(window_dilations.begin(), 0);
        input_left_pads.insert(input_left_pads.begin(), 0);
        input_right_pads.insert(input_right_pads.begin(), 0);

        pooling_dims = {2, 3, 4};

        return DevicePool3D::MakeArgumentPointer(p_in_dev,
                                                 p_out_dev,
                                                 p_out_indices_dev,
                                                 input_lengths,
                                                 window_lengths,
                                                 output_lengths,
                                                 input_stride,
                                                 output_stride,
                                                 indices_stride,
                                                 window_strides,
                                                 window_dilations,
                                                 input_left_pads,
                                                 input_right_pads,
                                                 pooling_dims);
    }
};
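After this change the 2D pooling op is a thin adapter: it lifts the NCHW problem description to NCDHW with a unit depth dimension and forwards to DevicePool3dFwd_NDHWC_NDHWC. The standalone sketch below illustrates the lifting on the host; the sizes are made-up and index_t is a local stand-in, only the insert positions mirror the code above.

// Sketch of the NCHW -> NCDHW lifting done by DevicePool2dFwd_NHWC_NHWC::MakeArgumentPointer.
#include <cassert>
#include <vector>

using index_t = int; // stand-in for ck::index_t in this standalone sketch

int main()
{
    std::vector<index_t> input_lengths{2, 16, 32, 32}; // N, C, Hi, Wi (assumed values)
    std::vector<index_t> window_lengths{3, 3};         // Y, X

    // NCHW to NCDHW: a unit D dimension is inserted after C ...
    input_lengths.insert(input_lengths.begin() + 2, 1); // -> {2, 16, 1, 32, 32}

    // ... and the window gets a unit Z in front (YX to ZYX).
    window_lengths.insert(window_lengths.begin(), 1); // -> {1, 3, 3}

    assert(input_lengths.size() == 5 && window_lengths.size() == 3);
    return 0;
}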
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
          typename OutDataType,
          typename IndexDataType, // enable if OutputIndex == true
          typename ComputeDataType,
          ck::ReduceTensorOp ReduceOpId,
          bool OutputIndex,
          ck::index_t BlockSize,
          ck::index_t MThreadClusterSize,
          ck::index_t KThreadClusterSize,
          ck::index_t MThreadSliceSize,
          ck::index_t KThreadSliceSize,
          ck::index_t InSrcOutDstVectorSize>
struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5,
                                                          3,
                                                          InDataType,
                                                          OutDataType,
                                                          IndexDataType,
                                                          tensor_layout::convolution::NDHWC,
                                                          tensor_layout::convolution::NDHWC,
                                                          ReduceOpId,
                                                          OutputIndex>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr index_t InOutRank  = 5;
    static constexpr index_t WindowRank = 3;

    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
    using InElementwiseOperation =
        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
    using AccElementwiseOperation =
        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;

    static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    static auto
    MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
                                   std::vector<ck::index_t> output_ncdhw_lengths,
                                   std::vector<ck::index_t> input_ncdhw_stride,
                                   std::vector<ck::index_t> output_ncdhw_stride,
                                   std::vector<ck::index_t> window_spatial_zyx_lengths,
                                   std::vector<ck::index_t> window_zyx_strides,
                                   std::vector<ck::index_t> window_zyx_dilations,
                                   std::vector<ck::index_t> input_left_dhw_pads,
                                   std::vector<ck::index_t> input_right_dhw_pads)
    {
        const index_t N  = input_ncdhw_lengths[0];
        const index_t C  = input_ncdhw_lengths[1];
        const index_t Di = input_ncdhw_lengths[2];
        const index_t Hi = input_ncdhw_lengths[3];
        const index_t Wi = input_ncdhw_lengths[4];

        const index_t Do = output_ncdhw_lengths[2];
        const index_t Ho = output_ncdhw_lengths[3];
        const index_t Wo = output_ncdhw_lengths[4];

        const index_t Z = window_spatial_zyx_lengths[0];
        const index_t Y = window_spatial_zyx_lengths[1];
        const index_t X = window_spatial_zyx_lengths[2];

        const index_t WindowStrideD = window_zyx_strides[0];
        const index_t WindowStrideH = window_zyx_strides[1];
        const index_t WindowStrideW = window_zyx_strides[2];

        const index_t WindowDilationD = window_zyx_dilations[0];
        const index_t WindowDilationH = window_zyx_dilations[1];
        const index_t WindowDilationW = window_zyx_dilations[2];

        const index_t InLeftPadD = input_left_dhw_pads[0];
        const index_t InLeftPadH = input_left_dhw_pads[1];
        const index_t InLeftPadW = input_left_dhw_pads[2];

        const index_t InRightPadD = input_right_dhw_pads[0];
        const index_t InRightPadH = input_right_dhw_pads[1];
        const index_t InRightPadW = input_right_dhw_pads[2];

        const index_t MRaw = N * Do * Ho * Wo * C;
        const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;

        const index_t KRaw = Z * Y * X;
        const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
        // A[ReduceM, ReduceK]
        const index_t Ni_stride = input_ncdhw_stride[0];
        const index_t Ci_stride = input_ncdhw_stride[1];
        const index_t Di_stride = input_ncdhw_stride[2];
        const index_t Hi_stride = input_ncdhw_stride[3];
        const index_t Wi_stride = input_ncdhw_stride[4];

        const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));

        const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
            in_grid_desc_n_di_hi_wi_c,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
            in_grid_desc_n_dip_hip_wip_c,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));

        const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
            in_grid_desc_n_z_do_y_ho_x_wo_c,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
                       make_merge_transform(make_tuple(Z, Y, X))),
            make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
            in_grid_desc_reducemraw_reducekraw,
            make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B[ReduceM]
        const index_t No_stride = output_ncdhw_stride[0];
        const index_t Co_stride = output_ncdhw_stride[1];
        const index_t Do_stride = output_ncdhw_stride[2];
        const index_t Ho_stride = output_ncdhw_stride[3];
        const index_t Wo_stride = output_ncdhw_stride[4];
        // output descriptor over (N, Do, Ho, Wo, C) with the user-provided NDHWC strides
        const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
            make_tuple(N, Do, Ho, Wo, C),
            make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
        const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
            out_grid_desc_n_do_ho_wo_c,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
            make_tuple(Sequence<0>{}));

        const auto out_grid_desc_reducem =
            transform_tensor_descriptor(out_grid_desc_reducemraw,
                                        make_tuple(make_right_pad_transform(MRaw, MPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));

        return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
    }
    using ABGridDescs =
        decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));

    using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
    using BGridDesc_M   = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_in_dev,
                 OutDataType* p_out_dev,
                 IndexDataType* p_out_indices_dev,
                 std::vector<ck::index_t>& input_ncdhw_lengths,
                 std::vector<ck::index_t>& output_ncdhw_lengths,
                 std::vector<ck::index_t>& input_ncdhw_stride,
                 std::vector<ck::index_t>& output_ncdhw_stride,
                 std::vector<ck::index_t>&, // indices_ncdhw_stride
                 std::vector<ck::index_t>& window_spatial_zyx_lengths,
                 std::vector<ck::index_t>& window_zyx_strides,
                 std::vector<ck::index_t>& window_zyx_dilations,
                 std::vector<ck::index_t>& input_left_dhw_pads,
                 std::vector<ck::index_t>& input_right_dhw_pads)
            : p_in_dev_{p_in_dev},
              p_out_dev_{p_out_dev},
              p_out_indices_dev_{p_out_indices_dev},
              a_grid_desc_m_k_{},
              b_grid_desc_m_{},
              input_ncdhw_lengths_{input_ncdhw_lengths},
              output_ncdhw_lengths_{output_ncdhw_lengths},
              input_ncdhw_stride_{input_ncdhw_stride},
              output_ncdhw_stride_{output_ncdhw_stride}
        {
            const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
                                                              output_ncdhw_lengths,
                                                              input_ncdhw_stride,
                                                              output_ncdhw_stride,
                                                              window_spatial_zyx_lengths,
                                                              window_zyx_strides,
                                                              window_zyx_dilations,
                                                              input_left_dhw_pads,
                                                              input_right_dhw_pads);

            a_grid_desc_m_k_ = descs[I0];
            b_grid_desc_m_   = descs[I1];

            int32_t reduceLength = window_spatial_zyx_lengths[0] *
                                   window_spatial_zyx_lengths[1] * window_spatial_zyx_lengths[2];

            std::tie(in_element_op_, acc_element_op_) =
                reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
                    reduceLength);
        }
        const InDataType* p_in_dev_;
        OutDataType* p_out_dev_;
        IndexDataType* p_out_indices_dev_;
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_M b_grid_desc_m_;
        InElementwiseOperation in_element_op_;
        AccElementwiseOperation acc_element_op_;

        // for checking vector load/store
        std::vector<ck::index_t> input_ncdhw_lengths_;
        std::vector<ck::index_t> output_ncdhw_lengths_;
        std::vector<ck::index_t> input_ncdhw_stride_;
        std::vector<ck::index_t> output_ncdhw_stride_;
    };
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            // for NDHWC, the dim C is the fastest dimension, and is not reduced.
            // Hence, it is in M dimension for reduction kernel.
            static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K

            using gridwise_reduce =
                GridwiseReduction_mk_to_m_threadwise<InDataType,
                                                     OutDataType,
                                                     ComputeDataType,
                                                     IndexDataType,
                                                     AGridDesc_M_K,
                                                     BGridDesc_M,
                                                     ReduceOperation,
                                                     InElementwiseOperation,
                                                     AccElementwiseOperation,
                                                     InMemoryDataOperationEnum::Set,
                                                     false, // propagate_nan
                                                     BlockSize,
                                                     MThreadSliceSize,
                                                     KThreadSliceSize,
                                                     InSrcOutDstVectorDim,
                                                     InSrcOutDstVectorSize,
                                                     InSrcOutDstVectorSize>;

            const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
                                                         OutputIndex,
                                                         true,  // pooling need to return global index
                                                         false, // don't have index input
                                                         InDataType,
                                                         OutDataType,
                                                         ComputeDataType,
                                                         IndexDataType,
                                                         AGridDesc_M_K,
                                                         BGridDesc_M,
                                                         InElementwiseOperation,
                                                         AccElementwiseOperation>;

            ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);

            const index_t grid_size = (M / M_BlockTileSize);

            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.a_grid_desc_m_k_,
                                          arg.b_grid_desc_m_,
                                          arg.in_element_op_,
                                          arg.acc_element_op_,
                                          float(1),
                                          arg.p_in_dev_,
                                          nullptr,
                                          float(0),
                                          arg.p_out_dev_,
                                          arg.p_out_indices_dev_);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        // C should be fastest dimension
        if(pArg->input_ncdhw_stride_[1] != 1)
            return false;

        for(int i = 0; i < InOutRank; ++i)
        {
            if(pArg->input_ncdhw_stride_[i] == 1 &&
               pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
                return false;

            if(pArg->output_ncdhw_stride_[i] == 1 &&
               pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
                return false;
        }

        return true;
    }
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_dev,
                        void* p_out_dev,
                        void* p_out_indices_dev,
                        std::vector<ck::index_t> input_ncdhw_lengths,
                        std::vector<ck::index_t> window_zyx_lengths,
                        std::vector<ck::index_t> output_ncdhw_lengths,
                        std::vector<ck::index_t> input_ncdhw_stride,
                        std::vector<ck::index_t> output_ncdhw_stride,
                        std::vector<ck::index_t> indices_ncdhw_stride,
                        std::vector<ck::index_t> window_zyx_strides,
                        std::vector<ck::index_t> window_zyx_dilations,
                        std::vector<ck::index_t> input_left_dhw_pads,
                        std::vector<ck::index_t> input_right_dhw_pads,
                        std::vector<ck::index_t> pooling_dims) override
    {
        if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
           input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
           window_zyx_dilations.size() != WindowRank ||
           input_left_dhw_pads.size() != WindowRank || input_right_dhw_pads.size() != WindowRank)
            throw std::runtime_error("dimension is incorrect");

        if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
            throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");

        if(output_ncdhw_stride != indices_ncdhw_stride)
            throw std::runtime_error(
                "output_ncdhw_stride need to be equal to indices_ncdhw_stride for now");

        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
                                          static_cast<OutDataType*>(p_out_dev),
                                          static_cast<IndexDataType*>(p_out_indices_dev),
                                          input_ncdhw_lengths,
                                          output_ncdhw_lengths,
                                          input_ncdhw_stride,
                                          output_ncdhw_stride,
                                          indices_ncdhw_stride,
                                          window_zyx_lengths,
                                          window_zyx_strides,
                                          window_zyx_dilations,
                                          input_left_dhw_pads,
                                          input_right_dhw_pads);
    }
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
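DevicePool3dFwd_NDHWC_NDHWC takes output_ncdhw_lengths from the caller and only builds descriptors from them, so the caller is responsible for passing output spatial sizes that are consistent with the window, stride, dilation and padding. The helper below shows the standard pooling output-size arithmetic one would use to produce those lengths; it is general background, not code taken from this header.

// Standard pooling output-length arithmetic (general background, not from this file).
inline int pool_out_length(int len, int window, int stride, int dilation, int left_pad, int right_pad)
{
    const int effective_window = (window - 1) * dilation + 1;
    return (len + left_pad + right_pad - effective_window) / stride + 1;
}

// e.g. Do = pool_out_length(Di, Z, WindowStrideD, WindowDilationD, InLeftPadD, InRightPadD);
//      Ho = pool_out_length(Hi, Y, WindowStrideH, WindowDilationH, InLeftPadH, InRightPadH);
//      Wo = pool_out_length(Wi, X, WindowStrideW, WindowDilationW, InLeftPadW, InRightPadW);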
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_put_element.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// output[indices] = input
template <typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum MemOp,
          ck::index_t InVectorSize>
struct DevicePutElementImpl
    : public DevicePutElement<InDataType, IndexDataType, OutDataType, ElementwiseOperation, MemOp>
{
    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
        constexpr auto I0 = Number<0>{};

        const auto m            = desc_m.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * InVectorSize;
        const auto pad          = math::integer_least_multiple(m, loop_step) - m;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(m, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }
    static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
    {
        const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
        return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
    }

    using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1));

    using GridwisePutElement = GridwisePutElement_1D<InGrid1dDesc,
                                                     InDataType,
                                                     IndexDataType,
                                                     OutDataType,
                                                     ElementwiseOperation,
                                                     MemOp,
                                                     InVectorSize>;
    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_input,
                 const IndexDataType* p_indices,
                 OutDataType* p_output,
                 index_t input_length,
                 ElementwiseOperation elementwise_op)
            : p_input_{p_input},
              p_indices_{p_indices},
              p_output_{p_output},
              input_length_raw_{input_length},
              elementwise_op_{elementwise_op},
              blockSize_{256}
        {
        }

        const InDataType* p_input_;
        const IndexDataType* p_indices_;
        OutDataType* p_output_;
        index_t input_length_raw_;
        ElementwiseOperation elementwise_op_;
        index_t blockSize_;
    };
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            index_t gridSize = getAvailableComputeUnitCount(stream_config);

            InGrid1dDesc in_grid_desc =
                MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_);

            const auto kernel = kernel_put_element_1d<GridwisePutElement,
                                                      InGrid1dDesc,
                                                      InDataType,
                                                      IndexDataType,
                                                      OutDataType,
                                                      ElementwiseOperation>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(gridSize),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        in_grid_desc,
                                                        arg.p_input_,
                                                        arg.p_indices_,
                                                        arg.p_output_,
                                                        arg.elementwise_op_);
            return elapsed_time;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg->input_length_raw_ % InVectorSize != 0)
        {
            return false;
        }

        return true;
    }
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_input,
                                                      const void* p_indices,
                                                      void* p_output,
                                                      index_t input_length,
                                                      index_t,
                                                      ElementwiseOperation elementwise_op) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_input),
                                          static_cast<const IndexDataType*>(p_indices),
                                          static_cast<OutDataType*>(p_output),
                                          input_length,
                                          elementwise_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
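Per the "// output[indices] = input" note, DevicePutElementImpl scatters each input element to the output position named by the matching index, applying the elementwise operation along the way and writing according to MemOp. A scalar host-side reference sketch of that semantics follows; it assumes the usual CK unary-op calling convention op(y, x) and ignores the InVectorSize vectorization that the device kernel uses.

// Reference semantics of the put-element op (host loop, no vectorization).
#include <cstddef>
#include <vector>

template <typename In, typename Index, typename Out, typename ElementwiseOp>
void put_element_reference(const std::vector<In>& input,
                           const std::vector<Index>& indices,
                           std::vector<Out>& output,
                           ElementwiseOp op)
{
    // output[indices[i]] = op(input[i]); MemOp == Set overwrites, other
    // InMemoryDataOperationEnum values (e.g. atomic add) would combine instead.
    for(std::size_t i = 0; i < input.size(); ++i)
    {
        Out value{};
        op(value, input[i]); // assumed op(y, x) convention: writes through the first argument
        output[static_cast<std::size_t>(indices[i])] = value;
    }
}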
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -28,6 +28,7 @@ template <typename InDataType,
          typename AccElementwiseOperation,
          bool PropagateNan,
          bool OutputIndex,
          bool TransformIndexKtoGlobal,
          bool HaveIndexInputIfOutputIndex,
          index_t BlockSize,
          index_t MThreadSliceSize,
...
...
@@ -260,6 +261,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
        const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
                                                     OutputIndex,
                                                     TransformIndexKtoGlobal,
                                                     HaveIndexInput,
                                                     InDataType,
                                                     OutDataType,
...
...