Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e70a4d19
Commit
e70a4d19
authored
Dec 13, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
ce72f286
0dacd895
Changes
472
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1869 additions
and
110 deletions
+1869
-110
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
...ation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
+132
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
...eration/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
+41
-40
include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp
...ration/gpu/device/device_normalization_bwd_gamma_beta.hpp
+61
-0
include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp
.../tensor_operation/gpu/device/device_normalization_fwd.hpp
+9
-9
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
...ration/gpu/device/impl/device_batchnorm_backward_impl.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
...pu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+53
-28
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+8
-6
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+7
-5
include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
..._operation/gpu/device/impl/device_elementwise_3d_impl.hpp
+371
-0
include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp
...eration/gpu/device/impl/device_elementwise_scale_impl.hpp
+329
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
...gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
+8
-6
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
...ce/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
+414
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp
.../device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp
+394
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+23
-9
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
...sor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
+3
-1
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
0 → 100644
View file @
e70a4d19
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
/**
* \brief Grouped Convolution Forward
*
* \details
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
* output : output image E[G, N, K, Ho, Wo]
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
ComputeType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
())
>
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
struct
DeviceGroupedConvFwdMultipleABD
:
public
BaseOperator
{
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
NumDTensor
==
DsLayout
::
Size
(),
"wrong! Inconsistent NumDTensor"
);
// If DataType is tuple, user has to pass std::array with pointers.
using
APointers
=
std
::
conditional_t
<
isMultiA
,
std
::
array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
std
::
array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
/**
* \brief Make argument pointer for grouped conv fwd.
*
* \param p_a A pointer to the input (std::array<const void*, NumA> with
pointers for multiple A).
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
pointers for multiple B).
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the output.
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Input left paddings.
* \param input_right_pads Input right paddings.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_a
,
BPointers
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
View file @
e70a4d19
...
...
@@ -3,21 +3,33 @@
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
/**
* \brief Grouped Convolution Forward
*
* \note This structure is deprecated (left for backwards compatibility). Please use
* DeviceGroupedConvFwdMultipleABD.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
...
...
@@ -30,36 +42,25 @@ template <index_t NDimSpatial,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
ComputeType
=
ADataType
>
struct
DeviceGroupedConvFwdMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
NumDTensor
==
DsLayout
::
Size
(),
"wrong! Inconsistent NumDTensor"
);
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
// input image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
typename
ComputeType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
())
>
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
using
DeviceGroupedConvFwdMultipleD
=
DeviceGroupedConvFwdMultipleABD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputeType
>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp
0 → 100644
View file @
e70a4d19
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
DYDataType
,
typename
XDataType
,
typename
MeanInvStdDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceNormalizationBwdGammaBeta
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
dyStrides
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
meanStrides
,
const
std
::
vector
<
index_t
>
invStdStrides
,
const
std
::
vector
<
index_t
>
outLengths
,
const
std
::
vector
<
index_t
>
dgammaStrides
,
const
std
::
vector
<
index_t
>
dbetaStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
void
*
p_dy
,
const
void
*
p_x
,
const
void
*
p_mean
,
const
void
*
p_invStd
,
void
*
p_dgamma
,
void
*
p_dbeta
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
DYDataType
,
typename
XDataType
,
typename
MeanInvStdDataType
,
typename
DGammaDataType
,
typename
DBetaDataType
,
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceNormalizationBwdGammaBetaPtr
=
std
::
unique_ptr
<
DeviceNormalizationBwdGammaBeta
<
DYDataType
,
XDataType
,
MeanInvStdDataType
,
DGammaDataType
,
DBetaDataType
,
Rank
,
NumReduceDim
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_normalization.hpp
→
include/ck/tensor_operation/gpu/device/device_normalization
_fwd
.hpp
View file @
e70a4d19
...
...
@@ -19,7 +19,7 @@ template <typename XDataType,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceNormalization
:
public
BaseOperator
struct
DeviceNormalization
Fwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
lengths
,
...
...
@@ -50,14 +50,14 @@ template <typename XDataType,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceNormalizationPtr
=
std
::
unique_ptr
<
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>>
;
using
DeviceNormalization
Fwd
Ptr
=
std
::
unique_ptr
<
DeviceNormalization
Fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
View file @
e70a4d19
...
...
@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
e70a4d19
...
...
@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
View file @
e70a4d19
...
...
@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
View file @
e70a4d19
...
...
@@ -17,15 +17,18 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Image to column for input layout NDHWC:
// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C]
// output : image [N, Di, Hi, Wi, C]
// Column to Image:
// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
// output : input image [G, N, Di, Hi, Wi, C]
// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
// output : input image [N, Di, Hi, Wi, G, C]
template
<
index_t
NDimSpatial
,
typename
ImageLayout
,
typename
InputDataType
,
...
...
@@ -43,6 +46,14 @@ struct DeviceColumnToImageImpl
OutputDataType
,
conv_tensor_rearrange_op
::
ColumnToImage
>
{
static
constexpr
bool
is_NSpatialGC
=
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
;
static
constexpr
bool
is_GNSpatialC
=
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -90,7 +101,7 @@ struct DeviceColumnToImageImpl
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
independent_filters
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
effs
)
{
...
...
@@ -100,23 +111,23 @@ struct DeviceColumnToImageImpl
C
*
ck
::
accumulate_n
<
index_t
>
(
filter_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
NStride
=
DoHoWo
*
gemm_m_k_strides
[
I
0
]
*
gemm_m_k_strides
[
I
1
];
const
index_t
NStride
=
DoHoWo
*
gemm_
g_
m_k_strides
[
I
1
]
*
gemm_
g_
m_k_strides
[
I
2
];
// Calculate the appropriate stride for each set of independent filters
// in each dimension
const
index_t
WStride
=
math
::
integer_divide_ceil
(
effs
[
XIdx
],
conv_filter_strides
[
XIdx
])
*
gemm_m_k_strides
[
I
0
];
const
index_t
WStride
=
math
::
integer_divide_ceil
(
effs
[
XIdx
],
conv_filter_strides
[
XIdx
])
*
gemm_
g_
m_k_strides
[
I
1
];
const
index_t
HStride
=
math
::
integer_divide_ceil
(
effs
[
YIdx
],
conv_filter_strides
[
YIdx
])
*
output_spatial_lengths
[
XIdx
]
*
gemm_m_k_strides
[
I
0
];
output_spatial_lengths
[
XIdx
]
*
gemm_
g_
m_k_strides
[
I
1
];
const
index_t
DStride
=
math
::
integer_divide_ceil
(
effs
[
ZIdx
],
conv_filter_strides
[
ZIdx
])
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
]
*
gemm_m_k_strides
[
I
0
];
gemm_
g_
m_k_strides
[
I
1
];
// Create descriptor for independent filters in each dimension and
// then merge them into column form
if
constexpr
(
NDimSpatial
==
1
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
WStride
,
gemm_m_k_strides
[
I
1
]));
make_tuple
(
NStride
,
WStride
,
gemm_
g_
m_k_strides
[
I
2
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
XIdx
])),
...
...
@@ -130,7 +141,7 @@ struct DeviceColumnToImageImpl
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
YIdx
],
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
HStride
,
WStride
,
gemm_m_k_strides
[
I
1
]));
make_tuple
(
NStride
,
HStride
,
WStride
,
gemm_
g_
m_k_strides
[
I
2
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
...
...
@@ -149,7 +160,7 @@ struct DeviceColumnToImageImpl
independent_filters
[
YIdx
],
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
DStride
,
HStride
,
WStride
,
gemm_m_k_strides
[
I
1
]));
make_tuple
(
NStride
,
DStride
,
HStride
,
WStride
,
gemm_
g_
m_k_strides
[
I
2
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
...
...
@@ -262,24 +273,27 @@ struct DeviceColumnToImageImpl
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Add
,
Block2ETileMap
>
;
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
:
C_
(
C
),
:
G_
(
G
),
C_
(
C
),
X_
(
filter_spatial_lengths
[
NDimSpatial
-
I1
]),
p_in_
{
static_cast
<
const
InputDataType
*>
(
p_in
)},
p_out_
{
static_cast
<
OutputDataType
*>
(
p_out
)},
...
...
@@ -289,6 +303,9 @@ struct DeviceColumnToImageImpl
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
gemm_g_m_k_strides
[
I0
];
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
image_g_n_c_wis_strides
[
I0
];
const
index_t
x_eff
=
(
filter_spatial_lengths
[
XIdx
]
-
1
)
*
conv_filter_dilations
[
XIdx
]
+
1
;
const
index_t
y_eff
=
...
...
@@ -354,7 +371,7 @@ struct DeviceColumnToImageImpl
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
gemm_m_k_strides
,
gemm_
g_
m_k_strides
,
independent_filters
,
effs
);
const
auto
out_grid_desc_m_k
=
...
...
@@ -387,10 +404,9 @@ struct DeviceColumnToImageImpl
// Memory offsets to next set of independent filters,
// move to independent filters in each dimension
const
index_t
in_offset
=
x_idx
*
gemm_m_k_strides
[
0
]
+
y_idx
*
gemm_m_k_strides
[
0
]
*
output_spatial_lengths
[
XIdx
]
+
z_idx
*
gemm_m_k_strides
[
0
]
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
];
(
x_idx
+
y_idx
*
output_spatial_lengths
[
XIdx
]
+
z_idx
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
])
*
gemm_g_m_k_strides
[
I1
];
// Move to independent filters in appropriate dimensions
const
index_t
out_offset
=
x_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
XIdx
]
+
...
...
@@ -417,6 +433,7 @@ struct DeviceColumnToImageImpl
}
}
const
ck
::
index_t
G_
;
const
ck
::
index_t
C_
;
const
ck
::
index_t
X_
;
...
...
@@ -434,6 +451,8 @@ struct DeviceColumnToImageImpl
std
::
vector
<
const
InputDataType
*>
p_in_container_
;
std
::
vector
<
OutputDataType
*>
p_out_container_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
};
struct
Invoker
:
public
BaseInvoker
...
...
@@ -451,6 +470,7 @@ struct DeviceColumnToImageImpl
OutputGridDesc
,
OutputDataType
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>
,
GridwiseTensorRearrangeKernel
>
;
// Execute each set of independent filters
...
...
@@ -460,7 +480,7 @@ struct DeviceColumnToImageImpl
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
InputGridDesc
>
(
arg
.
out_grid_desc_m_k_container_
[
i
]);
const
index_t
grid_size
=
block_2_tile_map
.
CalculateGridSize
(
arg
.
in_grid_desc_m_k_container_
[
i
]);
block_2_tile_map
.
CalculateGridSize
(
arg
.
in_grid_desc_m_k_container_
[
i
])
*
arg
.
G_
;
elapsed_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
...
...
@@ -470,7 +490,9 @@ struct DeviceColumnToImageImpl
arg
.
p_in_container_
[
i
],
arg
.
out_grid_desc_m_k_container_
[
i
],
arg
.
p_out_container_
[
i
],
block_2_tile_map
);
arg
.
G_
,
block_2_tile_map
,
arg
.
compute_ptr_offset_of_batch_
);
}
return
elapsed_time
;
}
...
...
@@ -485,8 +507,7 @@ struct DeviceColumnToImageImpl
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
using
namespace
tensor_layout
::
convolution
;
if
constexpr
(
!
(
std
::
is_same_v
<
ImageLayout
,
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNDHWC
>
))
if
constexpr
(
!
(
is_NSpatialGC
||
is_GNSpatialC
))
{
return
false
;
}
...
...
@@ -534,13 +555,14 @@ struct DeviceColumnToImageImpl
static
auto
MakeArgument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
...
@@ -548,13 +570,14 @@ struct DeviceColumnToImageImpl
{
return
Argument
{
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_m_k_strides
,
gemm_
g_
m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
...
...
@@ -566,13 +589,14 @@ struct DeviceColumnToImageImpl
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
...
@@ -580,13 +604,14 @@ struct DeviceColumnToImageImpl
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_m_k_strides
,
gemm_
g_
m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
e70a4d19
...
...
@@ -385,9 +385,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
...
@@ -397,7 +399,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
Make
Default
Block2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -429,7 +431,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
Make
Default
Block2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
...
...
@@ -481,10 +483,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
GridwiseGemm
::
Make
Default
AsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
GridwiseGemm
::
Make
Default
BsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
e70a4d19
...
...
@@ -145,7 +145,8 @@ template <index_t NumDimM,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
typename
ComputeDataType
=
ADataType
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceContractionMultipleD_Xdl_CShuffle
:
public
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
...
...
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
CDEElementwiseOperation
,
ComputeDataType
>
{
using
DeviceOp
=
DeviceContractionMultipleD_Xdl_CShuffle
;
...
...
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({{}},
{{}}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
...
...
@@ -595,7 +595,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return
false
;
}
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
ck
::
get_device_name
()
!=
"gfx940"
&&
ck
::
get_device_name
()
!=
"gfx941"
&&
ck
::
get_device_name
()
!=
"gfx942"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
0 → 100644
View file @
e70a4d19
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
index_t
NumDim_m
,
// choose how to set dims
index_t
NumDim_n
,
index_t
NumDim_k
,
index_t
MPerThread
,
index_t
NPerThread
,
index_t
KPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
DeviceElementwise3dImpl
:
public
DeviceElementwise
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
NumDim_m
+
NumDim_n
+
NumDim_k
>
{
static
constexpr
index_t
NumDim
=
NumDim_m
+
NumDim_n
+
NumDim_k
;
static
constexpr
int
NumInput
=
InDataTypeTuple
::
Size
();
static
constexpr
int
NumOutput
=
OutDataTypeTuple
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static_assert
(
NumInput
==
InScalarPerVectorSeq
::
Size
()
&&
NumOutput
==
OutScalarPerVectorSeq
::
Size
(),
"Tuple size is inconsistent with the number of in/out!"
);
static
auto
GenerateInDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
nullptr
);
},
Number
<
NumInput
>
{});
}
static
auto
GenerateOutDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
nullptr
);
},
Number
<
NumOutput
>
{});
}
using
InDataTypePointerTuple
=
decltype
(
GenerateInDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
GenerateOutDataTypePointerTuple
());
template
<
typename
Desc_MNK
>
static
auto
PadDescriptor_MNK
(
Desc_MNK
desc_mnk
,
index_t
gridSize
,
index_t
blockSize
,
index_t
num_threads_m
,
index_t
num_threads_n
,
index_t
num_threads_k
)
{
std
::
ignore
=
blockSize
;
std
::
ignore
=
gridSize
;
const
auto
m
=
desc_mnk
.
GetLength
(
I0
);
const
auto
n
=
desc_mnk
.
GetLength
(
I1
);
const
auto
k
=
desc_mnk
.
GetLength
(
I2
);
const
index_t
loop_step_m
=
num_threads_m
*
MPerThread
;
const
index_t
loop_step_n
=
num_threads_n
*
NPerThread
;
const
index_t
loop_step_k
=
num_threads_k
*
KPerThread
;
const
auto
pad_m
=
math
::
integer_least_multiple
(
m
,
loop_step_m
)
-
m
;
const
auto
pad_n
=
math
::
integer_least_multiple
(
n
,
loop_step_n
)
-
n
;
const
auto
pad_k
=
math
::
integer_least_multiple
(
k
,
loop_step_k
)
-
k
;
const
auto
desc_mnk_pad
=
transform_tensor_descriptor
(
desc_mnk
,
make_tuple
(
make_right_pad_transform
(
m
,
pad_m
),
make_right_pad_transform
(
n
,
pad_n
),
make_right_pad_transform
(
k
,
pad_k
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
desc_mnk_pad
;
}
static
auto
MakeDescriptor_MNK
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
stride
,
index_t
gridSize
,
index_t
blockSize
,
index_t
num_threads_m
,
index_t
num_threads_n
,
index_t
num_threads_k
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumDim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
NumDim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDim_m
,
1
>::
type
();
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDim_m
,
NumDim_m
+
NumDim_n
,
1
>::
type
();
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDim_m
+
NumDim_n
,
NumDim
,
1
>::
type
();
const
auto
mLengths
=
get_container_subset
(
tupleOfShape
,
mDimIds
);
const
auto
nLengths
=
get_container_subset
(
tupleOfShape
,
nDimIds
);
const
auto
kLengths
=
get_container_subset
(
tupleOfShape
,
kDimIds
);
// merge nd to 3d desc - [s0 * s1 * ...]
if
constexpr
(
NumDim
>
3
)
{
const
auto
desc_mnk
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
PadDescriptor_MNK
(
desc_mnk
,
gridSize
,
blockSize
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
}
else
return
PadDescriptor_MNK
(
desc
,
gridSize
,
blockSize
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
}
template
<
index_t
TupleSize
>
static
auto
GenerateInOutGrid3dDescTuple
(
Number
<
TupleSize
>
)
{
return
generate_tuple
(
[
&
](
auto
)
{
if
constexpr
(
NumDim
>
3
)
{
return
MakeDescriptor_MNK
({
1
,
1
,
1
},
{
1
,
1
,
1
},
1
,
1
,
1
,
1
,
1
);
}
else
{
return
MakeDescriptor_MNK
({
1
},
{
1
},
1
,
1
,
1
,
1
,
1
);
};
},
Number
<
TupleSize
>
{});
}
using
OutGrid3dDescTuple
=
decltype
(
GenerateInOutGrid3dDescTuple
(
Number
<
NumOutput
>
{}));
using
InGrid3dDescTuple
=
decltype
(
GenerateInOutGrid3dDescTuple
(
Number
<
NumInput
>
{}));
using
GridwiseElementwise
=
GridwiseElementwise_3D
<
InGrid3dDescTuple
,
OutGrid3dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
MPerThread
,
NPerThread
,
KPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
:
lengths_
(
lengths
),
inStridesArray_
(
inStridesArray
),
outStridesArray_
(
outStridesArray
),
elementwise_op_
(
elementwise_op
),
blockSize_
(
256
)
{
static_assert
(
NumDim_m
>
0
,
""
);
static_assert
(
NumDim_n
>
0
,
""
);
static_assert
(
NumDim_k
>
0
,
""
);
in_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
in_dev_buffers
[
I
.
value
]);
},
Number
<
NumInput
>
{});
out_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
out_dev_buffers
[
I
.
value
]);
},
Number
<
NumOutput
>
{});
}
InDataTypePointerTuple
in_dev_buffers_
;
OutDataTypePointerTuple
out_dev_buffers_
;
std
::
array
<
index_t
,
NumDim
>
lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
index_t
blockSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
)
*
arg
.
blockSize_
;
index_t
num_threads_m
=
gridSize
/
(
16
*
16
);
index_t
num_threads_n
=
16
;
index_t
num_threads_k
=
16
;
auto
in_grid_3d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_MNK
(
arg
.
lengths_
,
arg
.
inStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
},
Number
<
NumInput
>
{});
auto
out_grid_3d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_MNK
(
arg
.
lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
},
Number
<
NumOutput
>
{});
const
auto
kernel
=
kernel_elementwise_3d
<
GridwiseElementwise
,
InGrid3dDescTuple
,
OutGrid3dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
in_grid_3d_desc_tuple
,
out_grid_3d_desc_tuple
,
arg
.
in_dev_buffers_
,
arg
.
out_dev_buffers_
,
arg
.
elementwise_op_
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
return
elapsed_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
if
((
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
return
false
;
}
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
pArg
==
nullptr
)
return
false
;
if
(
pArg
->
lengths_
.
back
()
%
MPerThread
!=
0
)
return
false
;
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
index_t
scalarPerVector
,
index_t
vectorDim
)
{
if
(
strides
[
vectorDim
]
==
1
&&
(
lengths
[
vectorDim
]
%
scalarPerVector
==
0
||
lengths
[
vectorDim
]
%
scalarPerVector
==
lengths
[
vectorDim
]))
{
return
true
;
}
if
(
strides
[
vectorDim
]
>=
scalarPerVector
)
{
return
true
;
}
return
false
;
};
bool
valid
=
true
;
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
valid
=
valid
&&
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
),
NumDim_m
-
1
);
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
valid
=
valid
&&
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
),
NumDim
-
1
);
});
return
valid
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp
0 → 100644
View file @
e70a4d19
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
typename
UnaryOperation
,
typename
Scale
,
index_t
NumDim
,
index_t
MPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
DeviceElementwiseImpl
:
public
DeviceElementwise
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
UnaryOperation
,
Scale
,
NumDim
>
{
static
constexpr
int
NumInput
=
InDataTypeTuple
::
Size
();
static
constexpr
int
NumOutput
=
OutDataTypeTuple
::
Size
();
static_assert
(
NumInput
==
InScalarPerVectorSeq
::
Size
()
&&
NumOutput
==
OutScalarPerVectorSeq
::
Size
(),
"Tuple size is inconsistent with the number of in/out!"
);
static
auto
GenerateInDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
nullptr
);
},
Number
<
NumInput
>
{});
};
static
auto
GenerateOutDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
nullptr
);
},
Number
<
NumOutput
>
{});
};
using
InDataTypePointerTuple
=
decltype
(
GenerateInDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
GenerateOutDataTypePointerTuple
());
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
MPerThread
;
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
m
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumDim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
NumDim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
// merge nd to 1d desc - [s0 * s1 * ...]
if
constexpr
(
NumDim
>
1
)
{
const
auto
desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
NumDim
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
}
else
return
PadDescriptor_M_1d
(
desc
,
gridSize
,
blockSize
);
}
template
<
index_t
TupleSize
>
static
auto
GenerateInOutGrid1dDescTuple
(
Number
<
TupleSize
>
)
{
return
generate_tuple
(
[
&
](
auto
)
{
if
constexpr
(
NumDim
>
1
)
{
return
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
);
}
else
{
return
MakeDescriptor_M
({
1
},
{
1
},
1
,
1
);
};
},
Number
<
TupleSize
>
{});
};
using
InGrid1dDescTuple
=
decltype
(
GenerateInOutGrid1dDescTuple
(
Number
<
NumInput
>
{}));
using
OutGrid1dDescTuple
=
decltype
(
GenerateInOutGrid1dDescTuple
(
Number
<
NumOutput
>
{}));
using
GridwiseElementwise
=
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
OutGrid1dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
UnaryOperation
,
Scale
,
MPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
,
UnaryOperation
unary_op
,
Scale
scale_op
)
:
lengths_
(
lengths
),
inStridesArray_
(
inStridesArray
),
outStridesArray_
(
outStridesArray
),
elementwise_op_
(
elementwise_op
),
unary_op_
(
unary_op
),
scale_op_
(
scale_op
),
blockSize_
(
256
)
{
in_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
in_dev_buffers
[
I
.
value
]);
},
Number
<
NumInput
>
{});
out_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
out_dev_buffers
[
I
.
value
]);
},
Number
<
NumOutput
>
{});
}
InDataTypePointerTuple
in_dev_buffers_
;
OutDataTypePointerTuple
out_dev_buffers_
;
std
::
array
<
index_t
,
NumDim
>
lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
UnaryOperation
unary_op_
;
Scale
scale_op_
;
index_t
blockSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
);
auto
in_grid_1d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_M
(
arg
.
lengths_
,
arg
.
inStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
);
},
Number
<
NumInput
>
{});
auto
out_grid_1d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_M
(
arg
.
lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
);
},
Number
<
NumOutput
>
{});
const
auto
kernel
=
kernel_elementwise_1d
<
GridwiseElementwise
,
InGrid1dDescTuple
,
OutGrid1dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
UnaryOperation
,
Scale
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
in_grid_1d_desc_tuple
,
out_grid_1d_desc_tuple
,
arg
.
in_dev_buffers_
,
arg
.
out_dev_buffers_
,
arg
.
elementwise_op_
,
arg
.
unary_op_
,
arg
.
scale_op_
);
return
elapsed_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
arg
.
lengths_
.
back
()
%
MPerThread
!=
0
)
return
false
;
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
index_t
scalarPerVector
)
{
if
(
strides
.
back
()
==
1
&&
lengths
.
back
()
%
scalarPerVector
==
0
)
return
true
;
if
(
strides
.
back
()
!=
1
&&
scalarPerVector
==
1
)
return
true
;
return
false
;
};
bool
valid
=
true
;
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
arg
.
lengths_
,
arg
.
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
arg
.
lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
});
return
valid
;
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
,
UnaryOperation
unary_op
,
Scale
scale_op
)
{
return
Argument
{
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
,
unary_op
,
scale_op
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
,
UnaryOperation
unary_op
,
Scale
scale_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
,
unary_op
,
scale_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
e70a4d19
...
...
@@ -305,9 +305,11 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
...
@@ -317,7 +319,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
Make
Default
Block2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -349,7 +351,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
Make
Default
Block2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
...
...
@@ -407,10 +409,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
GridwiseGemm
::
Make
Default
AsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
GridwiseGemm
::
Make
Default
BsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
e70a4d19
...
...
@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
0 → 100644
View file @
e70a4d19
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferScalarPerVector
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferScalarPerVector
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v4
,
typename
ComputeDataType
=
EDataType
>
struct
DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
:
public
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
GridwiseGemm
=
GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
GemmSpec
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferScalarPerVector
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferScalarPerVector
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load
<
GridwiseGemm
,
ADataType
,
BDataType
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
typename
GridwiseGemm
::
AGridDesc_AK0_M_AK1
,
typename
GridwiseGemm
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
Block2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
const
auto
K
=
arg
.
a_grid_desc_m_k_
.
GetLength
(
I1
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
(
!
ck
::
is_lds_direct_load_supported
())
{
return
false
;
}
// Check vector load/store.
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// Check vector load of A.
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
if
(
arg
.
MRaw_
%
ABlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// Check vector load of B.
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
if
(
arg
.
NRaw_
%
BBlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// Check vector load of Ds.
// For now, only the RowMajor layout is supported.
bool
all_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
all_valid
=
false
;
}
});
if
(
!
all_valid
)
{
return
false
;
}
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
StrideDs
,
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
StrideDs
,
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
v4
,
"v4"
}};
// clang-format off
str
<<
"DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferScalarPerVector
<<
", "
<<
BBlockTransferScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
e70a4d19
...
...
@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return
false
;
}
}
else
if
(
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
)
else
if
(
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
||
is_same_v
<
AccDataType
,
double
>
))
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
e70a4d19
...
...
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffle"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
...
...
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
;
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp
0 → 100644
View file @
e70a4d19
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferScalarPerVector
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferScalarPerVector
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v4
,
typename
ComputeDataType
=
EDataType
>
struct
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
:
public
DeviceGemm
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
static
constexpr
auto
I1
=
Number
<
1
>
{};
using
GridwiseGemm
=
GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
<
ALayout
,
BLayout
,
ck
::
Tuple
<>
,
ELayout
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
GemmSpec
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferScalarPerVector
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferScalarPerVector
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load
<
GridwiseGemm
,
ADataType
,
BDataType
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
typename
GridwiseGemm
::
AGridDesc_AK0_M_AK1
,
typename
GridwiseGemm
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
Block2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
// arg.p_ds_grid_,
ck
::
Tuple
<>
{},
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
const
auto
K
=
arg
.
a_grid_desc_m_k_
.
GetLength
(
I1
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
(
!
ck
::
is_lds_direct_load_supported
())
{
return
false
;
}
// Check vector load/store.
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// Check vector load of A.
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
if
(
arg
.
MRaw_
%
ABlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// Check vector load of B.
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
if
(
arg
.
NRaw_
%
BBlockTransferScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
using
EmptyDsPointers
=
std
::
array
<
const
void
*
,
0
>
;
using
EmptyDsStrides
=
std
::
array
<
ck
::
index_t
,
0
>
;
return
Argument
{
p_a
,
p_b
,
EmptyDsPointers
{},
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
EmptyDsStrides
{},
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
using
EmptyDsPointers
=
std
::
array
<
const
void
*
,
0
>
;
using
EmptyDsStrides
=
std
::
array
<
ck
::
index_t
,
0
>
;
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
EmptyDsPointers
{},
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideA
,
StrideB
,
EmptyDsStrides
{},
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
v4
,
"v4"
}};
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffle_LdsDirectLoad"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferScalarPerVector
<<
", "
<<
BBlockTransferScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
]
<<
", "
<<
"Prefetch: "
<<
NumGemmKPrefetchStage
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
e70a4d19
...
...
@@ -59,7 +59,8 @@ template <typename ADataType,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
ComputeType
=
CDataType
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemmSplitK
<
ALayout
,
BLayout
,
...
...
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
static
constexpr
LoopScheduler
LoopSched
=
make_default_loop_scheduler
();
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
...
...
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
K0
Padded
_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
...
...
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
MPadded_
,
NPadded_
,
KPadded_
,
K0_
,
K0
Padded
_
,
k_batch_
),
a_element_op
(
a_element_op_
),
b_element_op
(
b_element_op_
),
...
...
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
const
auto
K0
=
karg
.
K0
;
const
auto
K0
Padded
=
karg
.
K0
Padded
;
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
Padded
);
float
ave_time
=
0
;
...
...
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
Padded
(
K
,
KBatch
),
KBatch
,
a_element_op
,
b_element_op
,
...
...
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
Padded
(
K
,
KBatch
),
KBatch
,
a_element_op
,
b_element_op
,
...
...
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
return
GridwiseGemm
::
GetTypeString
();
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
str
<<
GridwiseGemm
::
GetTypeString
()
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
return
str
.
str
();
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
View file @
e70a4d19
...
...
@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
}
}
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment