gaoqiong / composable_kernel · Commit 83be9a70

Support multi AB for grouped conv fwd xdl

Authored Nov 07, 2023 by Bartlomiej Kocot; committed by Bartłomiej Kocot, Nov 07, 2023
Parent: 98fd41f5
Showing 12 changed files with 587 additions and 189 deletions (+587 −189):

  include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp                           +77  −11
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp    +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp  +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp                        +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp             +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp              +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp     +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp         +2   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp          +332 −150
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp                                +148 −7
  include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                                +5   −2
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp                         +11  −5
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp  (+77 −11)

@@ -6,18 +6,42 @@
 #include <array>

 #include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
+#include "ck/utility/is_detected.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-// Convolution Forward:
-// input : input image A[G, N, C, Hi, Wi],
-// input : weight B[G, K, C, Y, X],
-// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
-// output : output image E[G, N, K, Ho, Wo]
-// C = a_op(A) * b_op(B)
-// E = cde_op(C, D0, D1, ...)
+template <typename T>
+using is_tuple = decltype(std::declval<T&>().IsTuple());
+
+/**
+ * \brief Grouped Convolution Forward
+ *
+ * \details
+ * input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
+ * input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
+ * input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
+ * output : output image E[G, N, K, Ho, Wo]
+ *
+ * C = a_op(A, A1...) * b_op(B, B1...)
+ * E = cde_op(C, D0, D1, ...)
+ *
+ * \tparam NDimSpatial Number of spatial dimensions.
+ * \tparam ALayout Input layout (also for a1, a2...).
+ * \tparam BLayout Weight layout (also for b1, b2...).
+ * \tparam DsLayout Ds layouts.
+ * \tparam ELayout Output layout.
+ * \tparam ADataType Input data type. Pass tuple if there is multiple A.
+ * \tparam BDataType Weight data type. Pass tuple if there is multiple B.
+ * \tparam DsDataType D data types.
+ * \tparam EDataType Output data type.
+ * \tparam AElementwiseOperation A elementwise operation.
+ * \tparam BElementwiseOperation B elementwise operation.
+ * \tparam CDEElementwiseOperation CDE elementwise operation.
+ * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
+ */
 template <index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
@@ -30,18 +54,60 @@ template <index_t NDimSpatial,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
-          typename ComputeType = ADataType>
+          typename ComputeType =
+              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
+                                      Number<0>,
+                                      ADataType>())> // ComputeType is InputType by default (first
+                                                     // in tuple for MultiAB), unpack if tuple was
+                                                     // passed
 struct DeviceGroupedConvFwdMultipleD : public BaseOperator
 {
+    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
+    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
+
+    static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
+    static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();

     static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");

+    // If DataType is tuple, user has to pass std::array with pointers.
+    using APointers =
+        std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
+    using BPointers =
+        std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
+
+    /**
+     * \brief Make argument pointer for grouped conv fwd.
+     *
+     * \param p_a A pointer to the input (std::array<const void*, NumA> with pointers for multiple A).
+     * \param p_b A pointer to the weight (std::array<const void*, NumA> with pointers for multiple B).
+     * \param p_ds A pointers to the Ds.
+     * \param p_e A pointers to the output.
+     * \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
+     * \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
+     * \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
+     * \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
+     * \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
+     * \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
+     * \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
+     * \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
+     * \param conv_filter_strides Convolution filter strides.
+     * \param conv_filter_dilations Convolution filter dilations.
+     * \param input_left_pads Input left paddings.
+     * \param input_right_pads Input right paddings.
+     * \param a_element_op A elementwise operation object.
+     * \param b_element_op B elementwise operation object.
+     * \param cde_element_op CDE elementwise operation object.
+     * \return Pointer to the argument.
+     */
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
-        const void* p_a, // input image
-        const void* p_b, // weight
+        APointers p_a,
+        BPointers p_b,
         const std::array<const void*, NumDTensor>& p_ds,
-        void* p_e, // output image
+        void* p_e,
         const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
         const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
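
The new isMultiA/isMultiB/APointers machinery above combines a member-detection check (is_detected probing for IsTuple()) with std::conditional_t switching the pointer argument type. The sketch below reproduces that pattern in plain, self-contained C++17; FwdOp, num_tensors, and the std::tuple specialization are illustrative stand-ins, not CK code.

// Standalone sketch, not CK code: how a tuple-like data type can switch the pointer
// argument type at compile time, mirroring isMultiA / NumATensor / APointers above.
#include <array>
#include <cstdio>
#include <tuple>
#include <type_traits>

template <typename T>
struct is_tuple : std::false_type {};
template <typename... Ts>
struct is_tuple<std::tuple<Ts...>> : std::true_type {};

template <typename T>
constexpr std::size_t num_tensors()
{
    if constexpr(is_tuple<T>::value) { return std::tuple_size<T>::value; } // like GetNumABTensors
    else { return 1; }
}

template <typename ADataType>
struct FwdOp // hypothetical device op, one template parameter for brevity
{
    static constexpr bool isMultiA          = is_tuple<ADataType>::value;
    static constexpr std::size_t NumATensor = num_tensors<ADataType>();

    // Multi-A: caller passes an array of raw pointers; single A: one raw pointer.
    using APointers =
        std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;

    static void MakeArgument(APointers) { std::puts(isMultiA ? "multi A" : "single A"); }
};

int main()
{
    float a0 = 0.f, a1 = 0.f;
    std::array<const void*, 2> as = {&a0, &a1};
    FwdOp<std::tuple<float, float>>::MakeArgument(as); // prints "multi A"
    FwdOp<float>::MakeArgument(&a0);                   // prints "single A"
}
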
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp  (+2 −2)

@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
         std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

         // element-wise op
         AElementwiseOp a_element_op_;
@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
                 typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
-                ComputePtrOffsetOfStridedBatch<NumDTensor>,
+                ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
                 has_main_loop>;

             return launch_and_time_kernel(
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp  (+2 −2)

@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         std::vector<Block2ETileMap> block_2_etile_map_container_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

         // element-wise op
         AElementwiseOp a_element_op_;
@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                 DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                 DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                 Block2ETileMap,
-                ComputePtrOffsetOfStridedBatch<NumDTensor>,
+                ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
                 has_main_loop>;

             return launch_and_time_kernel(
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp  (+2 −2)

@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
         Block2CTileMap block_2_ctile_map_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;

         // element-wise op
         OutElementwiseOperation a_element_op_;
@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                 remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
                 remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
                 remove_reference_t<DeviceOp::Block2CTileMap>,
-                ComputePtrOffsetOfStridedBatch<I0>,
+                ComputePtrOffsetOfStridedBatch<>,
                 has_main_loop,
                 has_double_loop>;
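
The `<I0>` → `<>` changes in the backward-weight kernels rely on the defaulted template parameters that device_grouped_conv_utils.hpp introduces further down (NumATensor = 1, NumBTensor = 1, NumDTensor = 0). A minimal sketch of that defaulting, with int in place of ck::index_t and no other members:

// Simplified stand-in, not the real struct: defaulted non-type template parameters
// let callers simply write ComputePtrOffsetOfStridedBatch<>.
#include <cstdio>

template <int NumATensor = 1, int NumBTensor = 1, int NumDTensor = 0>
struct ComputePtrOffsetOfStridedBatch
{
    static constexpr int num_a = NumATensor;
    static constexpr int num_d = NumDTensor;
};

int main()
{
    ComputePtrOffsetOfStridedBatch<>        bwd_weight; // defaults: 1 A, 1 B, 0 D tensors
    ComputePtrOffsetOfStridedBatch<1, 1, 2> fwd_two_d;  // explicit: 2 D tensors
    std::printf("%d %d\n", bwd_weight.num_d, fwd_two_d.num_d);
}
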
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp  (+2 −2)

@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
         Block2CTileMap block_2_ctile_map_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;

         OutElementwiseOperation a_element_op_;
         InElementwiseOperation b_element_op_;
@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
                 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                 remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
-                ComputePtrOffsetOfStridedBatch<I0>,
+                ComputePtrOffsetOfStridedBatch<>,
                 has_main_loop>;

             using EmptyTuple = Tuple<>;
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp  (+2 −2)

@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         Block2CTileMap block_2_ctile_map_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;

         index_t M01_;
         index_t N01_;
@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                 remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                 remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                 remove_reference_t<DeviceOp::Block2CTileMap>,
-                ComputePtrOffsetOfStridedBatch<I0>,
+                ComputePtrOffsetOfStridedBatch<>,
                 has_main_loop>;

             return launch_and_time_kernel(stream_config,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp  (+2 −2)

@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         DefaultBlock2CTileMap block_2_ctile_map_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

         // element-wise op
         AElementwiseOperation a_element_op_;
@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
                 DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
                 DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11,
                 DefaultBlock2CTileMap,
-                ComputePtrOffsetOfStridedBatch<NumDTensor>,
+                ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
                 has_main_loop,
                 has_double_loop>;
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp  (+2 −2)

@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_;

         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

         // element-wise op
         AElementwiseOperation a_element_op_;
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                 typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
-                ComputePtrOffsetOfStridedBatch<NumDTensor>,
+                ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
                 has_main_loop>;

             return launch_and_time_kernel(stream_config,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp  (+332 −150)

@@ -19,6 +19,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -56,7 +57,8 @@ namespace {
  *
  */
 template <typename GridwiseGemm,
-          typename ABDataType,
+          typename AsPointer, // tuples if multi AB, pointers if no
+          typename BsPointer,
           typename DsPointer,
           typename EDataType,
           typename AElementwiseOperation,
@@ -68,14 +70,16 @@ template <typename GridwiseGemm,
           typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2ETileMap,
           typename ComputePtrOffsetOfBatch,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool isMultiA,
+          bool isMultiB>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
-        const ABDataType* __restrict__ p_a_grid,
-        const ABDataType* __restrict__ p_b_grid,
+        AsPointer p_as_grid,
+        BsPointer p_bs_grid,
         DsPointer p_ds_grid,
         EDataType* __restrict__ p_e_grid,
         const AElementwiseOperation a_element_op,
@@ -98,13 +102,8 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
     const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -117,8 +116,26 @@ __global__ void
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
+    if constexpr(isMultiA || isMultiB)
+    {
+        AsPointer p_as_grid_grp;
+        BsPointer p_bs_grid_grp;
+
+        const auto as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
+
+        static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
+        static_for<0, NumATensor, 1>{}(
+            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
+
+        const auto bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
+
+        static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
+        static_for<0, NumBTensor, 1>{}(
+            [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid_grp,
+                                                      p_bs_grid_grp,
                                                       p_ds_grid_grp,
                                                       p_e_grid + e_batch_offset,
                                                       p_shared,
@@ -130,9 +147,32 @@ __global__ void
                                                       ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                       e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                                       block_2_ctile_map);
+    }
+    else
+    {
+        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid + a_batch_offset,
+                                                      p_bs_grid + b_batch_offset,
+                                                      p_ds_grid_grp,
+                                                      p_e_grid + e_batch_offset,
+                                                      p_shared,
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      cde_element_op,
+                                                      a_grid_desc_k0_m_k1,
+                                                      b_grid_desc_k0_n_k1,
+                                                      ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                      e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                                      block_2_ctile_map);
+    }
 #else
-    ignore = p_a_grid;
-    ignore = p_b_grid;
+    ignore = p_as_grid;
+    ignore = p_bs_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = batch_count;
@@ -150,6 +190,9 @@ __global__ void
 } // namespace

+template <typename T>
+using is_tuple = decltype(std::declval<T&>().IsTuple());
+
 //
 // @brief Device Convolution operation.
 //
@@ -211,7 +254,12 @@ template <index_t NDimSpatial,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType = ADataType,
+          typename ComputeDataType =
+              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
+                                      Number<0>,
+                                      ADataType>()), // ComputeType is InputType by default (first
+                                                     // in tuple for MultiAB), unpack if tuple was
+                                                     // passed
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
@@ -230,6 +278,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
 {
     using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;

+    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
+    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
+
+    static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
+    static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();

     static constexpr auto I0 = Number<0>{};
@@ -325,51 +378,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
     using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

-    // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
-        ADataType, // TODO: distinguish A/B datatype
-        BDataType,
-        ComputeDataType,
-        AccDataType,
-        CShuffleDataType,
-        DsDataType,
-        EDataType,
-        AElementwiseOperation,
-        BElementwiseOperation,
-        CDEElementwiseOperation,
-        InMemoryDataOperationEnum::Set,
-        NumGemmKPrefetchStage,
-        BlockSize,
-        MPerBlock,
-        NPerBlock,
-        KPerBlock,
-        AK1,
-        BK1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
-        ABlockTransferThreadClusterLengths_AK0_M_AK1,
-        ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder,
-        ABlockTransferSrcVectorDim,
-        ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_AK1,
-        false,
-        ABlockLdsExtraM,
-        BBlockTransferThreadClusterLengths_BK0_N_BK1,
-        BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder,
-        BBlockTransferSrcVectorDim,
-        BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_BK1,
-        false,
-        BBlockLdsExtraN,
-        CShuffleMXdlPerWavePerShuffle,
-        CShuffleNXdlPerWavePerShuffle,
-        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
+    // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
+    // it to it
+    using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
+    using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
+
+#define GridwiseGemmTemplateParameters                                                           \
+    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType,   \
+        EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,       \
+        InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
+        KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,                        \
+        ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,  \
+        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,                               \
+        ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false,          \
+        ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1,                          \
+        BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder,                  \
+        BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector,                           \
+        BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN,                           \
+        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,                           \
+        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,                       \
+        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
+    // Use appropriate gridwise gemm
+    using GridwiseGemm =
+        std::conditional_t<isMultiA || isMultiB,
+                           GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
+                           GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
+
+    // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
+    using APointers =
+        std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
+    using BPointers =
+        std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
+
+    // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
+    // in initializer list what is required for single const pointer).
+    using AGridPointer = remove_cvref_t<
+        decltype(GetAGridPointer<isMultiA || isMultiB, GridwiseGemm, ADataType>())>;
+    using BGridPointer = remove_cvref_t<
+        decltype(GetBGridPointer<isMultiA || isMultiB, GridwiseGemm, BDataType>())>;

     // desc for blockwise copy
     using AGridDesc_AK0_M_AK1 =
@@ -392,8 +437,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(const void* p_a,
-                 const void* p_b,
+        Argument(APointers p_as,
+                 BPointers p_bs,
                  const std::array<const void*, NumDTensor>& p_ds,
                  void* p_e,
                  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
@@ -413,8 +458,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                 const AElementwiseOperation& a_element_op,
                 const BElementwiseOperation& b_element_op,
                 const CDEElementwiseOperation& cde_element_op)
-            : p_a_grid_{static_cast<const ADataType*>(p_a)},
-              p_b_grid_{static_cast<const BDataType*>(p_b)},
+            : p_as_grid_{},
+              p_bs_grid_{},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e)},
              num_group_{a_g_n_c_wis_lengths[0]},
@@ -458,9 +503,58 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
              input_right_pads_{input_right_pads}
        {
            // A/B/E Batch Stride
-            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+            if constexpr(isMultiA || isMultiB)
+            {
+                static_for<0, NumATensor, 1>{}([&](auto i) {
+                    // Init compute_ptr_offset_of_batch_ for multiple AB
+                    compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+
+                    // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
+                    // type is not tuple)
+                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
+                    // It is possible that one of the AB is a pointer and one is a tuple.
+                    // Then also use multiAB but we have to cast single pointer instead of tuple of
+                    // pointer.
+                    if constexpr(isMultiA)
+                    {
+                        // p_as is tuple
+                        p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
+                    }
+                    else
+                    {
+                        // if MultiB and not MultiA then p_as is single pointer
+                        p_as_grid_(i) = static_cast<const DataType*>(p_as);
+                    }
+                });
+                static_for<0, NumBTensor, 1>{}([&](auto i) {
+                    // Init compute_ptr_offset_of_batch_ for multiple AB
+                    compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+
+                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
+                    // It is possible that one of the AB is a pointer and one is a tuple.
+                    // Then also use multiAB but we have to cast single pointer instead of tuple of
+                    // pointer.
+                    if constexpr(isMultiB)
+                    {
+                        // p_bs is tuple
+                        p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
+                    }
+                    else
+                    {
+                        // if MultiA and not MultiB then p_bs is single pointer
+                        p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
+                    }
+                });
+            }
+            else
+            {
+                compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+                compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+
+                // p_as and p_bs are pointers
+                p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
+                p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
+            }

            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -477,8 +571,33 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
                    ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
            });
+            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];

            // populate desc for Ds/E
-            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+            if constexpr(isMultiA || isMultiB)
+            {
+                const auto as_grid_desc_ak0_m_ak1 =
+                    generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
+                const auto bs_grid_desc_bk0_n_bk1 =
+                    generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
+
+                if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
+                                               bs_grid_desc_bk0_n_bk1,
+                                               ds_grid_desc_m_n_,
+                                               e_grid_desc_m_n_,
+                                               block_2_etile_map_))
+                {
+                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            e_grid_desc_m_n_);
+
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            ds_grid_desc_m_n_);
+                }
+            }
+            else
+            {
+                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
                                               b_grid_desc_n_k_,
                                               ds_grid_desc_m_n_,
@@ -494,6 +613,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                        ds_grid_desc_m_n_);
                }
+            }
        }

        void Print() const
        {
@@ -505,9 +625,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
        }

        // private:
-        // pointers
-        const ADataType* p_a_grid_;
-        const BDataType* p_b_grid_;
+        // pointers (tuple if multi AB, pointer if no)
+        AGridPointer p_as_grid_;
+        BGridPointer p_bs_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;
@@ -529,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
        Block2ETileMap block_2_etile_map_;

        // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
+            compute_ptr_offset_of_batch_;

        // element-wise op
        AElementwiseOperation a_element_op_;
@@ -563,16 +684,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                arg.Print();
            }

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                            arg.b_grid_desc_n_k_,
-                                            arg.ds_grid_desc_m_n_,
-                                            arg.e_grid_desc_m_n_,
-                                            arg.block_2_etile_map_))
-            {
-                throw std::runtime_error(
-                    "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
-            }
-
            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
@@ -582,9 +693,60 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

-                const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
+                if constexpr(isMultiA || isMultiB)
+                {
+                    // Generate tuples with grid descriptors for each A and B
+                    const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
+                        [&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number<NumATensor>{});
+                    const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
+                        [&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number<NumBTensor>{});
+
+                    const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
+                        GridwiseGemm,
+                        AGridPointer,
+                        BGridPointer,
+                        typename GridwiseGemm::DsGridPointer,
+                        EDataType,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CDEElementwiseOperation,
+                        decltype(as_grid_desc_ak0_m_ak1),
+                        decltype(bs_grid_desc_bk0_n_bk1),
+                        DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+                        DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+                        Block2ETileMap,
+                        ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
+                        has_main_loop,
+                        isMultiA,
+                        isMultiB>;
+
+                    return launch_and_time_kernel(
+                        stream_config,
+                        kernel,
+                        dim3(grid_size),
+                        dim3(BlockSize),
+                        0,
+                        arg.p_as_grid_,
+                        arg.p_bs_grid_,
+                        arg.p_ds_grid_,
+                        arg.p_e_grid_,
+                        arg.a_element_op_,
+                        arg.b_element_op_,
+                        arg.cde_element_op_,
+                        arg.a_g_n_c_wis_lengths_[0], // Group count
+                        as_grid_desc_ak0_m_ak1,
+                        bs_grid_desc_bk0_n_bk1,
+                        arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                        arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                        arg.block_2_etile_map_,
+                        arg.compute_ptr_offset_of_batch_);
+                }
+                else
+                {
+                    const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
+                        GridwiseGemm,
+                        const ADataType*,
+                        const BDataType*,
                        typename GridwiseGemm::DsGridPointer,
                        EDataType,
                        AElementwiseOperation,
@@ -595,16 +757,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                        DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                        DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                        Block2ETileMap,
-                        ComputePtrOffsetOfStridedBatch<NumDTensor>,
-                        has_main_loop>;
+                        ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
+                        has_main_loop,
+                        isMultiA,
+                        isMultiB>;

                    return launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
+                                                  arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
+                                                  arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
                                                  arg.p_ds_grid_,
                                                  arg.p_e_grid_,
                                                  arg.a_element_op_,
@@ -617,6 +782,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                                                  arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                                  arg.block_2_etile_map_,
                                                  arg.compute_ptr_offset_of_batch_);
+                }
            };

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
@@ -791,12 +957,28 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
        }

        // check Gridwise GEMM
+        if constexpr(isMultiA || isMultiB)
+        {
+            // Genarate tuples with the same descriptors
+            const auto as_grid_desc_ak0_m_ak1 =
+                generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
+            const auto bs_grid_desc_bk0_n_bk1 =
+                generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
+
+            return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
+                                               bs_grid_desc_bk0_n_bk1,
+                                               arg.ds_grid_desc_m_n_,
+                                               arg.e_grid_desc_m_n_,
+                                               arg.block_2_etile_map_);
+        }
+        else
+        {
            return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                               arg.b_grid_desc_n_k_,
                                               arg.ds_grid_desc_m_n_,
                                               arg.e_grid_desc_m_n_,
                                               arg.block_2_etile_map_);
+        }
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
@@ -804,9 +986,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
    }

-    static auto MakeArgument(const void* p_a,
-                             const void* p_b,
+    static auto MakeArgument(APointers p_as,
+                             BPointers p_bs,
                             const std::array<const void*, NumDTensor>& p_ds,
                             void* p_e,
                             const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                             const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
@@ -824,8 +1006,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                             const BElementwiseOperation& b_element_op,
                             const CDEElementwiseOperation& cde_element_op)
    {
-        return Argument{p_a,
-                        p_b,
+        return Argument{p_as,
+                        p_bs,
                        p_ds,
                        p_e,
                        a_g_n_c_wis_lengths,
@@ -848,8 +1030,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(
-        const void* p_a,
-        const void* p_b,
+        APointers p_a,
+        BPointers p_b,
        const std::array<const void*, NumDTensor>& p_ds,
        void* p_e,
        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
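
The key structural change in this file is that a single std::conditional_t alias now decides which gridwise GEMM the device op builds on (GridwiseGemmMultipleABD_xdl_cshuffle when multi A/B, GridwiseGemmMultipleD_xdl_cshuffle otherwise), so the single-tensor path is left untouched. A stripped-down sketch of that selection pattern; the two kernel structs here are placeholders, not the real gridwise GEMMs:

// Stripped-down sketch, not CK code: one conditional_t alias selects the GEMM backend,
// so every member that refers to GridwiseGemm automatically follows the multi-AB choice.
#include <cstdio>
#include <type_traits>

struct GridwiseGemmMultipleD   { static void Run() { std::puts("single A/B path"); } };
struct GridwiseGemmMultipleABD { static void Run() { std::puts("multi A/B path"); } };

template <bool isMultiA, bool isMultiB>
struct DeviceOpSketch
{
    using GridwiseGemm =
        std::conditional_t<isMultiA || isMultiB, GridwiseGemmMultipleABD, GridwiseGemmMultipleD>;

    static void Run() { GridwiseGemm::Run(); }
};

int main()
{
    DeviceOpSketch<false, false>::Run(); // prints "single A/B path"
    DeviceOpSketch<true, false>::Run();  // prints "multi A/B path"
}
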
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp  (+148 −7)

@@ -9,31 +9,112 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <index_t NumDTensor>
+template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
 struct ComputePtrOffsetOfStridedBatch
 {
+    static constexpr bool isMultiAB = NumATensor > 1 || NumBTensor > 1;
+
     ComputePtrOffsetOfStridedBatch() = default;

     ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                    index_t BatchStrideB,
                                    Array<ck::index_t, NumDTensor> BatchStrideDs,
                                    index_t BatchStrideE)
-        : BatchStrideA_(BatchStrideA),
-          BatchStrideB_(BatchStrideB),
+        : BatchStrideA_(),
+          BatchStrideB_(),
           BatchStrideDs_(BatchStrideDs),
           BatchStrideE_(BatchStrideE)
     {
+        if constexpr(!isMultiAB)
+        {
+            BatchStrideA_ = BatchStrideA;
+            BatchStrideB_ = BatchStrideB;
+        }
+        else
+        {
+            static_assert("Invalid constructor for multiple A or B");
+        }
+    }
+
+    ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor> BatchStrideAs,
+                                   Array<ck::index_t, NumBTensor> BatchStrideBs,
+                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
+                                   index_t BatchStrideE)
+        : BatchStrideA_(),
+          BatchStrideB_(),
+          BatchStrideDs_(BatchStrideDs),
+          BatchStrideE_(BatchStrideE)
+    {
+        if constexpr(isMultiAB)
+        {
+            BatchStrideA_ = BatchStrideAs;
+            BatchStrideB_ = BatchStrideBs;
+        }
+        else
+        {
+            static_assert("Invalid constructor for single A and B");
+        }
     }

     __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
     {
+        if constexpr(!isMultiAB)
+        {
             return g_idx * static_cast<long_index_t>(BatchStrideA_);
+        }
+        else
+        {
+            static_assert("Invalid function for multiple A or B");
+            return 0;
+        }
     }

     __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
     {
+        if constexpr(!isMultiAB)
+        {
             return g_idx * static_cast<long_index_t>(BatchStrideB_);
+        }
+        else
+        {
+            static_assert("Invalid function for multiple A or B");
+            return 0;
+        }
+    }
+
+    __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
+    {
+        if constexpr(isMultiAB)
+        {
+            Array<long_index_t, NumATensor> as_offset;
+            static_for<0, NumATensor, 1>{}([&](auto i) {
+                as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]);
+            });
+            return as_offset;
+        }
+        else
+        {
+            static_assert("Invalid function for single A and B");
+            return BatchStrideA_;
+        }
+    }
+
+    __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
+    {
+        if constexpr(isMultiAB)
+        {
+            Array<long_index_t, NumBTensor> bs_offset;
+            static_for<0, NumBTensor, 1>{}([&](auto i) {
+                bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]);
+            });
+            return bs_offset;
+        }
+        else
+        {
+            static_assert("Invalid function for single A and B");
+            return BatchStrideB_;
+        }
     }

     __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
     {
@@ -54,13 +135,73 @@ struct ComputePtrOffsetOfStridedBatch
         return g_idx * static_cast<long_index_t>(BatchStrideE_);
     }

-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
+    // If multiAB use Array
+    using BatchStrideAType =
+        std::conditional_t<isMultiAB, Array<ck::index_t, NumATensor>, ck::index_t>;
+    using BatchStrideBType =
+        std::conditional_t<isMultiAB, Array<ck::index_t, NumBTensor>, ck::index_t>;
+
+    BatchStrideAType BatchStrideA_;
+    BatchStrideBType BatchStrideB_;
     Array<ck::index_t, NumDTensor> BatchStrideDs_;
     index_t BatchStrideE_;
     index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };

+template <bool isTuple, typename Tensors>
+constexpr static auto GetNumABTensors()
+{
+    if constexpr(isTuple)
+    {
+        return Number<Tensors::Size()>{};
+    }
+    else
+    {
+        return Number<1>{};
+    }
+}
+
+template <bool isTuple, typename GridwiseGemm, typename DataType>
+constexpr static auto GetAGridPointer()
+{
+    if constexpr(isTuple)
+    {
+        return typename GridwiseGemm::AsGridPointer{};
+    }
+    else
+    {
+        return Tuple<const DataType*>{};
+    }
+}
+
+template <bool isTuple, typename GridwiseGemm, typename DataType>
+constexpr static auto GetBGridPointer()
+{
+    if constexpr(isTuple)
+    {
+        return typename GridwiseGemm::BsGridPointer{};
+    }
+    else
+    {
+        return Tuple<const DataType*>{};
+    }
+}
+
+template <bool isTuple, typename Id, typename Type>
+constexpr static auto UnpackDataType()
+{
+    if constexpr(isTuple)
+    {
+        // unpack if tuple
+        return tuple_element_t<Id{}, Type>{};
+    }
+    else
+    {
+        // if no, return Type
+        return Type{};
+    }
+}
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
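
With per-tensor counts in the helper, the multi-AB path asks for one offset per A (or B) tensor via GetAsPtrOffset/GetBsPtrOffset instead of a single scalar. A host-only sketch of the assumed arithmetic, written against std::array rather than ck::Array:

// Standalone sketch, not CK code: per-group pointer offsets when each of several A
// tensors has its own batch (group) stride, mirroring GetAsPtrOffset above.
#include <array>
#include <cstdint>
#include <cstdio>

template <std::size_t NumATensor>
std::array<std::int64_t, NumATensor>
get_as_ptr_offset(int g_idx, const std::array<int, NumATensor>& batch_stride_a)
{
    std::array<std::int64_t, NumATensor> as_offset{};
    for(std::size_t i = 0; i < NumATensor; ++i)
        as_offset[i] = static_cast<std::int64_t>(g_idx) * batch_stride_a[i];
    return as_offset;
}

int main()
{
    // Two A tensors with different per-group strides; group index 3.
    const auto offs = get_as_ptr_offset<2>(3, {1024, 512});
    std::printf("A0 offset: %lld, A1 offset: %lld\n",
                static_cast<long long>(offs[0]), static_cast<long long>(offs[1]));
}
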
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp  (+5 −2)

@@ -85,10 +85,13 @@ struct Add
 struct ScaleAdd
 {
-    __host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
+    __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}

     template <typename Y, typename X0, typename X1>
-    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
+    }

     template <>
     __host__ __device__ void
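
ScaleAdd's primary operator() template now has a definition (previously only specializations did) and the constructor defaults scale to 1.f; the result is y = scale * x0 + x1 with the accumulation done in float. A minimal host-only sketch of that semantics, with static_cast standing in for ck::type_convert:

// Host-only sketch of the ScaleAdd semantics; illustrative, not the CK implementation.
#include <cstdio>

struct ScaleAdd
{
    explicit ScaleAdd(float scale = 1.f) : scale_(scale) {}

    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        // y = scale * x0 + x1, computed in float and converted back to Y
        y = static_cast<Y>(scale_ * static_cast<float>(x0) + static_cast<float>(x1));
    }

    float scale_;
};

int main()
{
    float y = 0.f;
    ScaleAdd{2.f}(y, 3, 0.5f); // y = 2 * 3 + 0.5 = 6.5
    std::printf("%g\n", y);
}
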
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp  (+11 −5)

@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     // A desc for source in blockwise copy
     template <typename AGridDesc_M_K>
     __host__ __device__ static constexpr auto
-    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
+    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
     {
         const auto M = a_grid_desc_m_k.GetLength(I0);
         const auto K = a_grid_desc_m_k.GetLength(I1);
@@ -219,7 +219,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     template <typename AsGridDesc_M_K>
     __host__ __device__ static constexpr auto
-    MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
+    MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
     {
         return generate_tuple(
             [&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
@@ -229,7 +229,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     // B desc for source in blockwise copy
     template <typename BGridDesc_N_K>
     __host__ __device__ static constexpr auto
-    MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
+    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
     {
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = b_grid_desc_n_k.GetLength(I1);
@@ -245,7 +245,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     template <typename BsGridDesc_N_K>
     __host__ __device__ static constexpr auto
-    MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
+    MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
     {
         return generate_tuple(
             [&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     // return block_id to E matrix tile idx (m0, n0) mapping
     template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
-    MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
     {
         return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
             e_grid_desc_m_n);
@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
                            Number<NumATensor>{});

+        static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
+                      "Src and Dst ScalarPerVector must be the same");
+
         auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
             ThisThreadBlock,
             AsDataType,
@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
                            Number<NumBTensor>{});

+        static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
+                      "Src and Dst ScalarPerVector must be the same");
+
         auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
             ThisThreadBlock,
             BsDataType,
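
Several multi-AB code paths in this commit build a tuple holding one copy of a grid descriptor per A or B tensor (generate_tuple(..., Number<NumATensor>{})) before calling CheckValidity or Run. The sketch below shows the same replication idea with std::index_sequence; GridDesc and replicate are illustrative names only:

// Standalone sketch, not CK code: build a tuple with NumTensor copies of a descriptor,
// analogous to generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{}).
#include <cstdio>
#include <tuple>
#include <utility>

struct GridDesc { int M; int K; };

template <typename Desc, std::size_t... Is>
auto replicate_impl(const Desc& d, std::index_sequence<Is...>)
{
    return std::make_tuple(((void)Is, d)...); // one copy of d per index
}

template <std::size_t NumTensor, typename Desc>
auto replicate(const Desc& d)
{
    return replicate_impl(d, std::make_index_sequence<NumTensor>{});
}

int main()
{
    GridDesc a{128, 64};
    auto as_grid_desc = replicate<2>(a); // std::tuple<GridDesc, GridDesc>
    std::printf("%d %d\n", std::get<0>(as_grid_desc).M, std::get<1>(as_grid_desc).K);
}
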