Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9f8ab221
Unverified
Commit
9f8ab221
authored
Oct 19, 2023
by
zjing14
Committed by
GitHub
Oct 19, 2023
Browse files
Merge branch 'develop' into add_int8_wmma_example_instance
parents
755ace59
b4fc4d0b
Changes
490
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4238 additions
and
292 deletions
+4238
-292
include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
.../tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
+60
-0
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
+6
-3
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
...on/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
+3
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+3
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
...eration/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
+2
-1
include/ck/tensor_operation/gpu/device/device_normalization.hpp
...e/ck/tensor_operation/gpu/device/device_normalization.hpp
+5
-3
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+621
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+847
-0
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp
...sor_operation/gpu/device/impl/device_elementwise_impl.hpp
+22
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
...gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
+766
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+88
-17
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+879
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+15
-53
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+3
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+877
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+29
-39
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+1
-45
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+2
-50
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
.../impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
+5
-49
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A0[M, K], B0[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmMultipleABD
:
public
BaseOperator
{
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
std
::
array
<
ck
::
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
ck
::
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
View file @
9f8ab221
...
...
@@ -20,7 +20,8 @@ template <typename ALayout,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
typename
ComputeType
=
CDataType
>
struct
DeviceGemmSplitK
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
...
...
@@ -48,7 +49,8 @@ template <typename ALayout,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
typename
ComputeType
=
CDataType
>
using
DeviceGemmSplitKPtr
=
std
::
unique_ptr
<
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -57,7 +59,8 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
CElementwiseOperation
,
ComputeType
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
View file @
9f8ab221
...
...
@@ -29,7 +29,9 @@ template <ck::index_t NDimSpatial,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
typename
CDEElementwiseOperation
,
typename
AComputeType
=
ADataType
,
typename
BComputeType
=
AComputeType
>
struct
DeviceGroupedConvBwdDataMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
9f8ab221
...
...
@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
,
typename
ComputeTypeA
=
InDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGroupedConvBwdWeight
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp
View file @
9f8ab221
...
...
@@ -29,7 +29,8 @@ template <index_t NDimSpatial,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
typename
CDEElementwiseOperation
,
typename
ComputeType
=
ADataType
>
struct
DeviceGroupedConvFwdMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
include/ck/tensor_operation/gpu/device/device_normalization.hpp
View file @
9f8ab221
...
...
@@ -14,8 +14,8 @@ namespace device {
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
...
...
@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
saveMeanStrides
,
const
std
::
vector
<
index_t
>
saveInvStdStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
double
epsilon
,
const
void
*
p_x
,
...
...
@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceNormalizationPtr
=
std
::
unique_ptr
<
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Image to column for input layout NDHWC:
// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C]
// output : image [N, Di, Hi, Wi, C]
template
<
index_t
NDimSpatial
,
typename
ImageLayout
,
typename
InputDataType
,
typename
OutputDataType
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
KPerBlock
,
typename
ThreadClusterLengths
,
index_t
ScalarPerVector
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
DeviceColumnToImageImpl
:
public
DeviceConvTensorRearrange
<
NDimSpatial
,
ImageLayout
,
InputDataType
,
OutputDataType
,
conv_tensor_rearrange_op
::
ColumnToImage
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
ZIdx
=
Number
<
I0
>
{};
static
constexpr
auto
YIdx
=
NDimSpatial
==
1
?
I0
:
Number
<
NDimSpatial
-
I2
>
{};
static
constexpr
auto
XIdx
=
Number
<
NDimSpatial
-
I1
>
{};
static
constexpr
auto
spatial_offset
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpecialization
::
MKPadding
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
0
/* NPerBlock*/
,
KPerBlock
};
// Calculate number of independent filters for given conv params
static
index_t
GetNumberOfIndependentFilters
(
const
index_t
input_spatial_len
,
const
index_t
left_pad
,
const
index_t
right_pad
,
const
index_t
filter_len
,
const
index_t
filter_stride
,
const
index_t
filter_dilation
,
const
index_t
image_offset
)
{
const
index_t
x_eff
=
(
filter_len
-
1
)
*
filter_dilation
+
1
;
const
index_t
next_filter_padded
=
math
::
integer_divide_ceil
(
x_eff
,
filter_stride
)
*
filter_stride
;
// If filter_stride >= x_eff then each filter is independent
const
index_t
independent_filter_stride
=
filter_stride
>=
x_eff
?
filter_stride
:
next_filter_padded
;
const
index_t
w_eff
=
input_spatial_len
-
image_offset
+
left_pad
+
right_pad
-
x_eff
;
// There are no independent filters
if
(
w_eff
<
0
)
return
0
;
const
index_t
independent_kernels_num
=
w_eff
/
independent_filter_stride
+
1
;
return
independent_kernels_num
;
}
// Make column form descriptor
static
auto
MakeInputDescriptor_M_K
(
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
independent_filters
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
effs
)
{
const
index_t
DoHoWo
=
ck
::
accumulate_n
<
index_t
>
(
output_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
filter_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
NStride
=
DoHoWo
*
gemm_m_k_strides
[
I0
]
*
gemm_m_k_strides
[
I1
];
// Calculate the appropriate stride for each set of independent filters
// in each dimension
const
index_t
WStride
=
math
::
integer_divide_ceil
(
effs
[
XIdx
],
conv_filter_strides
[
XIdx
])
*
gemm_m_k_strides
[
I0
];
const
index_t
HStride
=
math
::
integer_divide_ceil
(
effs
[
YIdx
],
conv_filter_strides
[
YIdx
])
*
output_spatial_lengths
[
XIdx
]
*
gemm_m_k_strides
[
I0
];
const
index_t
DStride
=
math
::
integer_divide_ceil
(
effs
[
ZIdx
],
conv_filter_strides
[
ZIdx
])
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
]
*
gemm_m_k_strides
[
I0
];
// Create descriptor for independent filters in each dimension and
// then merge them into column form
if
constexpr
(
NDimSpatial
==
1
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
WStride
,
gemm_m_k_strides
[
I1
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
XIdx
])),
make_pass_through_transform
(
CZYX
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_gemm_form_merged_filters
);
return
desc_m_k
;
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
YIdx
],
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
HStride
,
WStride
,
gemm_m_k_strides
[
I1
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
YIdx
],
independent_filters
[
XIdx
])),
make_pass_through_transform
(
CZYX
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_gemm_form_merged_filters
);
return
desc_m_k
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
ZIdx
],
independent_filters
[
YIdx
],
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
DStride
,
HStride
,
WStride
,
gemm_m_k_strides
[
I1
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
ZIdx
],
independent_filters
[
YIdx
],
independent_filters
[
XIdx
])),
make_pass_through_transform
(
CZYX
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_gemm_form_merged_filters
);
return
desc_m_k
;
}
}
// Use MakeADescriptor_M_K from grouped convolution forward
static
auto
MakeOutDescriptor_M_K
(
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
image_offsets
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
independent_filters
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
effs
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{
1
};
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{
1
};
std
::
array
<
index_t
,
NDimSpatial
+
3
>
c_g_n_k_wos_lengths
{
1
};
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
,
index_t
dst_offset
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
y
.
begin
()
+
dst_offset
);
};
copy
(
input_spatial_lengths
,
a_g_n_c_wis_lengths
,
spatial_offset
);
copy
(
filter_spatial_lengths
,
b_g_k_c_xs_lengths
,
spatial_offset
);
// Calculate descriptor only for independent filters
copy
(
independent_filters
,
c_g_n_k_wos_lengths
,
spatial_offset
);
// fill only significant values (C and N)
a_g_n_c_wis_lengths
[
I1
]
=
N
;
a_g_n_c_wis_lengths
[
I2
]
=
C
;
b_g_k_c_xs_lengths
[
I2
]
=
C
;
c_g_n_k_wos_lengths
[
I1
]
=
N
;
// Modify pads to apply offsets
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_with_offset
;
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
input_left_pads_with_offset
[
i
]
=
math
::
max
(
0
,
input_left_pads
[
i
]
-
image_offsets
[
i
]);
}
// Modify input spatial lengths to apply offsets
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
a_g_n_c_wis_lengths
[
i
+
spatial_offset
]
-=
math
::
max
(
0
,
image_offsets
[
i
]
-
input_left_pads
[
i
]);
}
// Strides to next independent filters
std
::
array
<
index_t
,
NDimSpatial
>
independent_filter_strides
;
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
index_t
independent_filter_stride
=
math
::
integer_divide_ceil
(
effs
[
i
],
conv_filter_strides
[
i
])
*
conv_filter_strides
[
i
];
// If conv stride is greater than whole filter size, use conv stride
independent_filter_strides
[
i
]
=
conv_filter_strides
[
i
]
>=
effs
[
i
]
?
conv_filter_strides
[
i
]
:
independent_filter_stride
;
}
// Calculate image form descriptor for the modified convolution problem
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ImageLayout
>(
a_g_n_c_wis_lengths
,
image_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
{},
// not needed for A Descriptor
c_g_n_k_wos_lengths
,
{},
// not needed for A Descriptor
// conv_filter_strides,
independent_filter_strides
,
conv_filter_dilations
,
input_left_pads_with_offset
,
input_right_pads
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
return
in_gemmm_gemmk_desc
;
}
using
InputGridDesc
=
remove_cvref_t
<
decltype
(
MakeInputDescriptor_M_K
(
1
,
1
,
{},
{},
{},
{},
{},
{}))
>
;
using
OutputGridDesc
=
remove_cvref_t
<
decltype
(
MakeOutDescriptor_M_K
(
1
,
1
,
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
InputGridDesc
>
(
InputGridDesc
{}))
>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Add
,
Block2ETileMap
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
:
C_
(
C
),
X_
(
filter_spatial_lengths
[
NDimSpatial
-
I1
]),
p_in_
{
static_cast
<
const
InputDataType
*>
(
p_in
)},
p_out_
{
static_cast
<
OutputDataType
*>
(
p_out
)},
image_g_n_c_wis_strides_
{
image_g_n_c_wis_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
const
index_t
x_eff
=
(
filter_spatial_lengths
[
XIdx
]
-
1
)
*
conv_filter_dilations
[
XIdx
]
+
1
;
const
index_t
y_eff
=
NDimSpatial
<
2
?
I1
:
(
filter_spatial_lengths
[
YIdx
]
-
1
)
*
conv_filter_dilations
[
YIdx
]
+
1
;
const
index_t
z_eff
=
NDimSpatial
<
3
?
I1
:
(
filter_spatial_lengths
[
ZIdx
]
-
1
)
*
conv_filter_dilations
[
ZIdx
]
+
1
;
// Iterate over sets of independent filters
for
(
int
z_img_offset
=
0
;
z_img_offset
<
z_eff
;
z_img_offset
+=
conv_filter_strides
[
ZIdx
])
{
for
(
int
y_img_offset
=
0
;
y_img_offset
<
y_eff
;
y_img_offset
+=
conv_filter_strides
[
YIdx
])
{
for
(
int
x_img_offset
=
0
;
x_img_offset
<
x_eff
;
x_img_offset
+=
conv_filter_strides
[
XIdx
])
{
std
::
array
<
index_t
,
NDimSpatial
>
image_offsets
;
std
::
array
<
index_t
,
NDimSpatial
>
effs
;
// Calculate the starting offset for a given set of
// independent filters
if
constexpr
(
NDimSpatial
==
1
)
{
image_offsets
=
{
x_img_offset
};
effs
=
{
x_eff
};
}
if
constexpr
(
NDimSpatial
==
2
)
{
image_offsets
=
{
y_img_offset
,
x_img_offset
};
effs
=
{
y_eff
,
x_eff
};
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
image_offsets
=
{
z_img_offset
,
y_img_offset
,
x_img_offset
};
effs
=
{
z_eff
,
y_eff
,
x_eff
};
}
std
::
array
<
index_t
,
NDimSpatial
>
independent_filters
;
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
independent_filters
[
i
]
=
GetNumberOfIndependentFilters
(
input_spatial_lengths
[
i
],
input_left_pads
[
i
],
input_right_pads
[
i
],
filter_spatial_lengths
[
i
],
conv_filter_strides
[
i
],
conv_filter_dilations
[
i
],
image_offsets
[
i
]);
}
const
index_t
independent_filters_acum
=
ck
::
accumulate_n
<
index_t
>
(
independent_filters
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
independent_filters_acum
<=
0
)
continue
;
const
auto
in_grid_desc_m_k
=
MakeInputDescriptor_M_K
(
N
,
C
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
gemm_m_k_strides
,
independent_filters
,
effs
);
const
auto
out_grid_desc_m_k
=
MakeOutDescriptor_M_K
(
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
image_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
image_offsets
,
independent_filters
,
effs
);
in_grid_desc_m_k_container_
.
push_back
(
in_grid_desc_m_k
);
out_grid_desc_m_k_container_
.
push_back
(
out_grid_desc_m_k
);
const
index_t
x_idx
=
x_img_offset
/
conv_filter_strides
[
XIdx
];
const
index_t
y_idx
=
y_img_offset
/
conv_filter_strides
[
YIdx
];
const
index_t
z_idx
=
z_img_offset
/
conv_filter_strides
[
ZIdx
];
const
index_t
x_offset_with_pad
=
math
::
max
(
0
,
x_img_offset
-
input_left_pads
[
XIdx
]);
const
index_t
y_offset_with_pad
=
math
::
max
(
0
,
y_img_offset
-
input_left_pads
[
YIdx
]);
const
index_t
z_offset_with_pad
=
math
::
max
(
0
,
z_img_offset
-
input_left_pads
[
ZIdx
]);
// Memory offsets to next set of independent filters,
// move to independent filters in each dimension
const
index_t
in_offset
=
x_idx
*
gemm_m_k_strides
[
0
]
+
y_idx
*
gemm_m_k_strides
[
0
]
*
output_spatial_lengths
[
XIdx
]
+
z_idx
*
gemm_m_k_strides
[
0
]
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
];
// Move to independent filters in appropriate dimensions
const
index_t
out_offset
=
x_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
XIdx
]
+
y_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
YIdx
]
+
z_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
ZIdx
];
const
InputDataType
*
p_in_with_offset
=
static_cast
<
const
InputDataType
*>
(
p_in
)
+
in_offset
;
OutputDataType
*
p_out_with_offset
=
static_cast
<
OutputDataType
*>
(
p_out
)
+
out_offset
;
p_in_container_
.
push_back
(
p_in_with_offset
);
p_out_container_
.
push_back
(
p_out_with_offset
);
}
}
}
}
void
Print
()
const
{
for
(
std
::
size_t
i
=
0
;
i
<
in_grid_desc_m_k_container_
.
size
();
i
++
)
{
std
::
cout
<<
in_grid_desc_m_k_container_
[
i
]
<<
std
::
endl
;
std
::
cout
<<
out_grid_desc_m_k_container_
[
i
]
<<
std
::
endl
;
}
}
const
ck
::
index_t
C_
;
const
ck
::
index_t
X_
;
const
InputDataType
*
p_in_
;
OutputDataType
*
p_out_
;
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads_
;
std
::
vector
<
InputGridDesc
>
in_grid_desc_m_k_container_
;
std
::
vector
<
OutputGridDesc
>
out_grid_desc_m_k_container_
;
std
::
vector
<
const
InputDataType
*>
p_in_container_
;
std
::
vector
<
OutputDataType
*>
p_out_container_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
float
elapsed_time
=
0.
f
;
const
auto
kernel
=
kernel_tensor_rearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
Block2ETileMap
,
GridwiseTensorRearrangeKernel
>
;
// Execute each set of independent filters
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
in_grid_desc_m_k_container_
.
size
();
i
++
)
{
const
auto
block_2_tile_map
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
InputGridDesc
>
(
arg
.
out_grid_desc_m_k_container_
[
i
]);
const
index_t
grid_size
=
block_2_tile_map
.
CalculateGridSize
(
arg
.
in_grid_desc_m_k_container_
[
i
]);
elapsed_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
in_grid_desc_m_k_container_
[
i
],
arg
.
p_in_container_
[
i
],
arg
.
out_grid_desc_m_k_container_
[
i
],
arg
.
p_out_container_
[
i
],
block_2_tile_map
);
}
return
elapsed_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
using
namespace
tensor_layout
::
convolution
;
if
constexpr
(
!
(
std
::
is_same_v
<
ImageLayout
,
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNDHWC
>
))
{
return
false
;
}
const
auto
w_pad_left
=
arg
.
input_left_pads_
[
NDimSpatial
-
I1
];
const
auto
w_pad_right
=
arg
.
input_right_pads_
[
NDimSpatial
-
I1
];
const
auto
dilation_x
=
arg
.
conv_filter_dilations_
[
NDimSpatial
-
I1
];
const
auto
stride_x
=
arg
.
conv_filter_strides_
[
NDimSpatial
-
I1
];
bool
is_w_packed
=
arg
.
image_g_n_c_wis_strides_
[
NDimSpatial
+
I2
]
==
arg
.
C_
;
bool
is_c_packed
=
arg
.
image_g_n_c_wis_strides_
[
I2
]
==
1
;
// check vector acces with c not packed
if
(
!
is_c_packed
&&
ScalarPerVector
!=
1
)
return
false
;
// check vector access of filter window row (only C if C is not packed)
if
(
!
is_w_packed
&&
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of filter window row (X * C)
if
(
arg
.
X_
*
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of pads (w_pad_left/w_pad_right * C)
if
(
w_pad_left
*
arg
.
C_
%
ScalarPerVector
!=
0
||
w_pad_right
*
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of with stride and pad
if
((
w_pad_left
!=
0
||
w_pad_right
!=
0
)
&&
stride_x
>
1
&&
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of with dilation
if
(
dilation_x
>
1
&&
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
bool
valid
=
true
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
in_grid_desc_m_k_container_
.
size
();
i
++
)
{
valid
&=
GridwiseTensorRearrangeKernel
::
CheckValidity
(
arg
.
in_grid_desc_m_k_container_
[
i
],
arg
.
out_grid_desc_m_k_container_
[
i
]);
}
return
valid
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
return
Argument
{
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceColumnToImage"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
ScalarPerVector
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AsGridDesc_AK0_M_AK1
,
typename
BsGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_contraction_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1
,
const
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
,
p_bs_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_as_grid
;
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
as_grid_desc_ak0_m_ak1
;
ignore
=
bs_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
}
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
AsDataType
,
typename
BsDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
DeviceContractionMultipleABD_Xdl_CShuffle
:
public
DeviceContractionMultipleABD
<
NumDimM
,
NumDimN
,
NumDimK
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceContractionMultipleABD_Xdl_CShuffle
;
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ComputeDataType
=
EDataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleABD_xdl_cshuffle
<
AsDataType
,
BsDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
static
constexpr
auto
matrix_padder
=
ck
::
tensor_operation
::
device
::
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_ms_ks_lengths_
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides_
)
{
assert
(
a_ms_ks_lengths_
.
size
()
==
NumDimM
+
NumDimK
&&
a_ms_ks_strides_
.
size
()
==
NumDimM
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
const
auto
a_ms_ks_lengths
=
to_tuple
(
a_ms_ks_lengths_
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_ms_ks_strides_
,
Number
<
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_ms_ks_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_ks_lengths
,
kDimIds
);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
__host__
__device__
static
auto
MakeAsGridDescriptor_M_K
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeAGridDescriptor_M_K
(
as_ms_ks_lengths
[
i
],
as_ms_ks_strides
[
i
]);
},
Number
<
NumATensor
>
{});
}
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths_
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides_
)
{
assert
(
b_ns_ks_lengths_
.
size
()
==
NumDimN
+
NumDimK
&&
b_ns_ks_strides_
.
size
()
==
NumDimN
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
const
auto
b_ns_ks_lengths
=
to_tuple
(
b_ns_ks_lengths_
,
Number
<
NumDimN
+
NumDimK
>
{});
const
auto
b_ns_ks_strides
=
to_tuple
(
b_ns_ks_strides_
,
Number
<
NumDimN
+
NumDimK
>
{});
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimN
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimN
,
NumDimN
+
NumDimK
,
1
>::
type
{};
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
b_ns_ks_lengths
,
kDimIds
);
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
b_ns_ks_lengths
,
nDimIds
);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_ns_ks
,
make_tuple
(
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
__host__
__device__
static
auto
MakeBsGridDescriptor_N_K
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeBGridDescriptor_N_K
(
bs_ns_ks_lengths
[
i
],
bs_ns_ks_strides
[
i
]);
},
Number
<
NumBTensor
>
{});
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
e_ms_ns_lengths_
,
const
std
::
vector
<
index_t
>&
e_ms_ns_strides_
)
{
assert
(
e_ms_ns_lengths_
.
size
()
==
NumDimM
+
NumDimN
&&
e_ms_ns_strides_
.
size
()
==
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
const
auto
e_ms_ns_lengths
=
to_tuple
(
e_ms_ns_lengths_
,
Number
<
NumDimM
+
NumDimN
>
{});
const
auto
e_ms_ns_strides
=
to_tuple
(
e_ms_ns_strides_
,
Number
<
NumDimM
+
NumDimN
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_ms_ns_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_ms_ns_lengths
,
nDimIds
);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const
auto
e_grid_desc_mraw_nraw
=
transform_tensor_descriptor
(
e_grid_desc_ms_ns
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeEGridDescriptor_M_N
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
using
AsGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAsGridDescriptor_M_K
({},
{}))
>
;
using
BsGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBsGridDescriptor_N_K
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
({},
{}))
>
;
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as_grid
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_as_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
as_grid_desc_m_k_
{},
bs_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
MakeEGridDescriptor_M_N
(
e_ms_ns_length
,
e_ms_ns_stride
)},
as_grid_desc_ak0_m_ak1_
{},
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
// populate pointer, desc for As
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
using
ADataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsDataType
>>
;
// A pointer
p_as_grid_
(
i
)
=
static_cast
<
const
ADataType
*>
(
p_as_grid
[
i
]);
// A desc
as_grid_desc_m_k_
(
i
)
=
MakeAGridDescriptor_M_K
(
a_ms_ks_lengths
[
i
],
a_ms_ks_strides
[
i
]);
});
// populate pointer, desc for Bs
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
using
BDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
BsDataType
>>
;
// B pointer
p_bs_grid_
(
i
)
=
static_cast
<
const
BDataType
*>
(
p_bs_grid
[
i
]);
// B desc
bs_grid_desc_n_k_
(
i
)
=
MakeBGridDescriptor_N_K
(
b_ns_ks_lengths
[
i
],
b_ns_ks_strides
[
i
]);
});
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
// D desc
ds_grid_desc_m_n_
(
i
)
=
MakeEGridDescriptor_M_N
(
d_ms_ns_lengths
[
i
],
d_ms_ns_strides
[
i
]);
});
// populate desc for Ds/E
if
(
GridwiseGemm
::
CheckValidity
(
as_grid_desc_m_k_
,
bs_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
}
// for sanity check of vector memory access
for
(
index_t
i
=
0
;
i
<
NumATensor
;
++
i
)
{
a_mz_stride_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
-
1
];
a_kz_stride_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
+
NumDimK
-
1
];
}
for
(
index_t
i
=
0
;
i
<
NumBTensor
;
++
i
)
{
b_nz_stride_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
-
1
];
b_kz_stride_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
+
NumDimK
-
1
];
}
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_nz_stride_
[
i
]
=
d_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
];
}
e_nz_stride_
=
e_ms_ns_stride
[
NumDimM
+
NumDimN
-
1
];
}
// pointers
typename
GridwiseGemm
::
AsGridPointer
p_as_grid_
;
typename
GridwiseGemm
::
BsGridPointer
p_bs_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
AsGridDesc_M_K
as_grid_desc_m_k_
;
BsGridDesc_N_K
bs_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1_
;
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1_
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
std
::
array
<
index_t
,
NumATensor
>
a_mz_stride_
;
std
::
array
<
index_t
,
NumATensor
>
a_kz_stride_
;
std
::
array
<
index_t
,
NumBTensor
>
b_nz_stride_
;
std
::
array
<
index_t
,
NumBTensor
>
b_kz_stride_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_stride_
;
index_t
e_nz_stride_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
as_grid_desc_m_k_
,
arg
.
bs_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_contraction_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
typename
GridwiseGemm
::
AsGridPointer
,
typename
GridwiseGemm
::
BsGridPointer
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AsGridDesc_AK0_M_AK1
,
DeviceOp
::
BsGridDesc_BK0_N_BK1
,
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
Block2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_as_grid_
,
arg
.
p_bs_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
as_grid_desc_ak0_m_ak1_
,
arg
.
bs_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
const
auto
K
=
arg
.
as_grid_desc_m_k_
[
I0
].
GetLength
(
I1
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// check vector load/store
{
bool
all_valid
=
true
;
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// vector memory access of A: could be on M or AK1 dimension
if
constexpr
(
ABlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
a_mz_stride_
[
i
]
==
1
&&
arg
.
as_grid_desc_ak0_m_ak1_
[
i
].
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
all_valid
=
false
;
}
}
else
{
if
(
!
(
arg
.
a_kz_stride_
[
i
]
==
1
&&
arg
.
as_grid_desc_ak0_m_ak1_
[
i
].
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
all_valid
=
false
;
}
}
});
// vector memory access of B: could be on N or BK1 dimension
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
BBlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
b_nz_stride_
[
i
]
==
1
&&
arg
.
bs_grid_desc_bk0_n_bk1_
[
i
].
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
all_valid
=
false
;
}
}
else
{
if
(
!
(
arg
.
b_kz_stride_
[
i
]
==
1
&&
arg
.
bs_grid_desc_bk0_n_bk1_
[
i
].
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
all_valid
=
false
;
}
}
});
// check vector load of Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
i
].
GetLength
(
I3
)
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
all_valid
=
false
;
}
});
// vector memory access of E: always on NPerBlock dimension
if
(
!
(
arg
.
e_nz_stride_
==
1
&&
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetLength
(
I3
)
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
all_valid
=
false
;
}
if
(
!
all_valid
)
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
as_grid_desc_m_k_
,
arg
.
bs_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
d_ms_ns_lengths
,
d_ms_ns_strides
,
e_ms_ns_length
,
e_ms_ns_stride
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_as
,
p_bs
,
p_ds
,
p_e
,
as_ms_ks_lengths
,
as_ms_ks_strides
,
bs_ns_ks_lengths
,
bs_ns_ks_strides
,
ds_ms_ns_lengths
,
ds_ms_ns_strides
,
e_ms_ns_length
,
e_ms_ns_stride
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceContractionMultipleABD_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp
View file @
9f8ab221
...
...
@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceElementwiseImpl<"
;
str
<<
"NumDim_"
<<
NumDim
<<
","
;
str
<<
"MPerThread_"
<<
MPerThread
<<
","
;
str
<<
"InScalarPerVector"
;
static_for
<
0
,
InScalarPerVectorSeq
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
str
<<
"_"
<<
InScalarPerVectorSeq
::
At
(
i
).
value
;
});
str
<<
","
;
str
<<
"OutScalarPerVector"
;
static_for
<
0
,
OutScalarPerVectorSeq
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
str
<<
"_"
<<
OutScalarPerVectorSeq
::
At
(
i
).
value
;
});
str
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AsGridDesc_AK0_M_AK1
,
typename
BsGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1
,
const
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
,
p_bs_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_as_grid
;
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
as_grid_desc_ak0_m_ak1
;
ignore
=
bs_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
}
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
DeviceGemmMultipleABD_Xdl_CShuffle
:
public
DeviceGemmMultipleABD
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmMultipleABD_Xdl_CShuffle
;
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
#if 0
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AsLayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideAs, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AsLayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideAs));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BsLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideBs));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BsLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideBs, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
#endif
using
ComputeDataType
=
EDataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleABD_xdl_cshuffle
<
AsDataType
,
BsDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
// desc for problem definition
using
AsGridDesc_M_K
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
template
MakeAsGridDescriptor_M_K
<
AsLayout
,
GemmSpec
>(
{},
{},
{}))
>
;
using
BsGridDesc_N_K
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
template
MakeBsGridDescriptor_N_K
<
BsLayout
,
GemmSpec
>(
{},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
template
MakeDsGridDescriptor_M_N
<
DsLayout
,
GemmSpec
>(
{},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>(
1
,
1
,
1
));
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as_grid
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
std
::
array
<
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_as_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
as_grid_desc_m_k_
{},
bs_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>(
MRaw
,
NRaw
,
StrideE
)},
as_grid_desc_ak0_m_ak1_
{},
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
MRaw_
{
MRaw
},
NRaw_
{
NRaw
},
KRaw_
{
KRaw
}
{
// populate pointer, desc for As
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
using
ALayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsLayout
>>
;
using
ADataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsDataType
>>
;
// A pointer
p_as_grid_
(
i
)
=
static_cast
<
const
ADataType
*>
(
p_as_grid
[
i
]);
// A desc
as_grid_desc_m_k_
(
i
)
=
GridwiseGemm
::
template
MakeAGridDescriptor_M_K
<
ALayout
,
GemmSpec
>(
MRaw
,
KRaw
,
StrideAs
[
i
]);
});
// populate pointer, desc for Bs
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
using
BLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
BsLayout
>>
;
using
BDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
BsDataType
>>
;
// B pointer
p_bs_grid_
(
i
)
=
static_cast
<
const
BDataType
*>
(
p_bs_grid
[
i
]);
// B desc
bs_grid_desc_n_k_
(
i
)
=
GridwiseGemm
::
template
MakeBGridDescriptor_N_K
<
BLayout
,
GemmSpec
>(
KRaw
,
NRaw
,
StrideBs
[
i
]);
});
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
// D desc
ds_grid_desc_m_n_
(
i
)
=
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
DLayout
,
GemmSpec
>(
MRaw
,
NRaw
,
StrideDs
[
i
]);
});
// populate desc for Ds/E
if
(
GridwiseGemm
::
CheckValidity
(
as_grid_desc_m_k_
,
bs_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
}
}
void
Print
()
const
{
// std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl;
// std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl;
// static_for<0, NumDTensor, 1>{}(
//[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
// std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
typename
GridwiseGemm
::
AsGridPointer
p_as_grid_
;
typename
GridwiseGemm
::
BsGridPointer
p_bs_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
AsGridDesc_M_K
as_grid_desc_m_k_
;
BsGridDesc_N_K
bs_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1_
;
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1_
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking vector load/store
index_t
MRaw_
;
index_t
NRaw_
;
index_t
KRaw_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
as_grid_desc_m_k_
,
arg
.
bs_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
typename
GridwiseGemm
::
AsGridPointer
,
typename
GridwiseGemm
::
BsGridPointer
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AsGridDesc_AK0_M_AK1
,
DeviceOp
::
BsGridDesc_BK0_N_BK1
,
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
Block2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_as_grid_
,
arg
.
p_bs_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
as_grid_desc_ak0_m_ak1_
,
arg
.
bs_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
const
auto
K
=
arg
.
as_grid_desc_m_k_
[
I0
].
GetLength
(
I1
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// check vector load/store
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
all_valid
=
true
;
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
using
ALayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsLayout
>>
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
all_valid
=
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
all_valid
=
false
;
}
}
else
{
all_valid
=
false
;
}
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
using
BLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
BsLayout
>>
;
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
all_valid
=
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
all_valid
=
false
;
}
}
else
{
all_valid
=
false
;
}
});
// check vector load of Ds
// only support RowMajor for now
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
all_valid
=
false
;
}
});
if
(
!
all_valid
)
{
return
false
;
}
// check vector store of E
// only support RowMajor for now
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
as_grid_desc_m_k_
,
arg
.
bs_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
std
::
array
<
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideAs
,
StrideBs
,
StrideDs
,
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
std
::
array
<
ck
::
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
ck
::
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_as
,
p_bs
,
p_ds
,
p_e
,
MRaw
,
NRaw
,
KRaw
,
StrideAs
,
StrideBs
,
StrideDs
,
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemmMultipleABD_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
9f8ab221
...
...
@@ -66,7 +66,8 @@ template <typename ALayout,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeType
=
CDataType
>
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGemm_Xdl_CShuffle
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -131,7 +132,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
,
ComputeType
>
;
ComputeTypeA
,
ComputeTypeB
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
9f8ab221
...
...
@@ -69,7 +69,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
,
ComputeType
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -126,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
PipelineVer
,
ComputeType
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
struct
Argument
:
public
GridwiseGemm
::
Argument
{
Argument
(
const
ADataType
*
p_a_grid_
,
const
BDataType
*
p_b_grid_
,
CDataType
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CElementwiseOperation
c_element_op_
)
:
GridwiseGemm
::
Argument
(
p_a_grid_
,
p_b_grid_
,
p_c_grid_
,
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
MPadded_
,
NPadded_
,
KPadded_
,
K0_
,
k_batch_
),
a_element_op
(
a_element_op_
),
b_element_op
(
b_element_op_
),
c_element_op
(
c_element_op_
)
{
}
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
CElementwiseOperation
c_element_op
;
};
using
DefaultBlock2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2CTileMap
;
// Invoker
...
...
@@ -167,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
,
b2c_map
);
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
static_cast
<
typename
GridwiseGemm
::
Argument
>
(
karg
),
b2c_map
,
karg
.
a_element_op
,
karg
.
b_element_op
,
karg
.
c_element_op
);
};
if
(
has_main_k0_block_loop
)
...
...
@@ -179,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
}
...
...
@@ -189,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
}
...
...
@@ -202,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
}
...
...
@@ -212,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
DefaultBlock2CTileMap
>
;
DefaultBlock2CTileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
Run
(
kernel
);
}
...
...
@@ -260,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
KBatch
)
{
return
Argument
{
p_a
,
return
Argument
(
p_a
,
p_b
,
p_c
,
M
,
...
...
@@ -278,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
KBatch
};
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -293,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
@@ -311,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
KBatch
);
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Conv backward data multiple D:
// input : output image A: [G, N, K, Ho, Wo]
// input : weight B: [G, K, C, Y, X],
// input : D0, D1, ... : [G, N, K, Ho, Wo]
// output : input image E: [G, N, C, Hi, Wi]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
index_t
NDimSpatial
,
typename
ALayout
,
// output image
typename
BLayout
,
// weight
typename
DsLayout
,
// bias
typename
ELayout
,
// input image
typename
ADataType
,
// output image
typename
BDataType
,
// weight
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
// bias
typename
EDataType
,
// input image
typename
AElementwiseOp
,
// output image
typename
BElementwiseOp
,
// weight
typename
CDEElementwiseOp
,
// C, bias, and input image
ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
NPerWMMA
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
:
public
DeviceGroupedConvBwdDataMultipleD
<
NDimSpatial
,
ALayout
,
// output image
BLayout
,
// weight
DsLayout
,
// bias
ELayout
,
// input image
ADataType
,
// output image
BDataType
,
// weight
DsDataType
,
// bias
EDataType
,
// input image
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
>
{
// TODO: Extend support for more spatial dimensions.
static_assert
(
NDimSpatial
==
2
||
NDimSpatial
==
3
,
"wrong! only implemented for 2D and 3D now"
);
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
// TODO: Add support for different A and B data types.
using
ABDataType
=
ADataType
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
K1
;
static
constexpr
auto
transform_conv_to_gemm
=
TransformConvBwdDataToGemm_v1
<
NDimSpatial
,
ConvBackwardDataSpecialization
,
K1
,
K1
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
true
/* DoPadGemmM */
,
true
/* DoPadGemmN */
>
{};
static
auto
GetDummyABDsEGridDescriptor
()
{
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
dummy_tensor_lengths
=
{
1
};
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
dummy_tensor_strides
=
{
1
};
const
std
::
array
<
index_t
,
NDimSpatial
>
dummy_spatial_lengths
=
{
1
};
const
auto
a_grid_desc_ak0_m_ak1
=
transform_conv_to_gemm
.
template
MakeADescriptor_AK0_M_AK1
<
ALayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
const
auto
b_grid_desc_bk0_n_bk1
=
transform_conv_to_gemm
.
template
MakeBDescriptor_BK0_N_BK1
<
BLayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
const
auto
ds_grid_desc_m_n
=
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
DLayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
},
Number
<
NumDTensor
>
{});
const
auto
e_grid_desc_m_n
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
ELayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
return
make_tuple
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
);
}
// desc
using
ABDsEGridDesc
=
decltype
(
GetDummyABDsEGridDescriptor
());
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
tuple_element_t
<
0
,
ABDsEGridDesc
>>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
tuple_element_t
<
1
,
ABDsEGridDesc
>>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
tuple_element_t
<
2
,
ABDsEGridDesc
>>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
tuple_element_t
<
3
,
ABDsEGridDesc
>>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
NumGemmKPrefetchStage
,
LoopSched
,
PipelineVer
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}));
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}));
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOp
&
a_element_op
,
const
BElementwiseOp
&
b_element_op
,
const
CDEElementwiseOp
&
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_k_wos_lengths
[
0
]},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_k_wos_lengths_
{
a_g_n_k_wos_lengths
},
a_g_n_k_wos_strides_
{
a_g_n_k_wos_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_c_wis_lengths_
{
ds_g_n_c_wis_lengths
},
ds_g_n_c_wis_strides_
{
ds_g_n_c_wis_strides
},
e_g_n_c_wis_lengths_
{
e_g_n_c_wis_lengths
},
e_g_n_c_wis_strides_
{
e_g_n_c_wis_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// populate Ds pointer
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
});
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_c_wis_strides
[
0
];
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_c_wis_strides
[
i
][
0
];
});
static
constexpr
auto
NonSpatialDimsNum
=
Number
<
3
>
{};
static
constexpr
auto
DIdx
=
Number
<
NonSpatialDimsNum
>
{};
static
constexpr
auto
HIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
>
{}
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
WIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
static
constexpr
auto
ZIdx
=
Number
<
NonSpatialDimsNum
>
{};
static
constexpr
auto
YIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
>
{}
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
XIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
// problem definition
const
index_t
Z
=
b_g_k_c_xs_lengths
[
ZIdx
];
const
index_t
Y
=
b_g_k_c_xs_lengths
[
YIdx
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
XIdx
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
NDimSpatial
==
3
?
ConvStrideD
/
GcdStrideDilationD
:
1
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
for
(
index_t
i_ztilde
=
0
;
i_ztilde
<
ZTilde
;
++
i_ztilde
)
{
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
ZDotSlice
=
NDimSpatial
==
3
?
math
::
integer_divide_ceil
(
Z
-
i_ztilde
,
ZTilde
)
:
1
;
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
YDotSlice
*
XDotSlice
*
ZDotSlice
<=
0
)
{
continue
;
}
std
::
array
<
index_t
,
NDimSpatial
>
tildes
;
if
constexpr
(
NDimSpatial
==
2
)
{
tildes
=
{
i_ytilde
,
i_xtilde
};
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
tildes
=
{
i_ztilde
,
i_ytilde
,
i_xtilde
};
}
else
{
throw
std
::
runtime_error
(
"wrong! only implemented for 2D and 3D now"
);
}
const
auto
a_grid_desc_ak0_m_ak1
=
transform_conv_to_gemm
.
template
MakeADescriptor_AK0_M_AK1
<
ALayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
const
auto
b_grid_desc_bk0_n_bk1
=
transform_conv_to_gemm
.
template
MakeBDescriptor_BK0_N_BK1
<
BLayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
DsGridDesc_M_N
ds_grid_desc_m_n
;
// populate Ds desc
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
ds_grid_desc_m_n
(
i
)
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
DLayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
[
i
],
ds_g_n_c_wis_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
});
const
auto
e_grid_desc_m_n
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
ELayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
// for check validity
ds_grid_desc_m_n_container_
.
push_back
(
ds_grid_desc_m_n
);
e_grid_desc_m_n_container_
.
push_back
(
e_grid_desc_m_n
);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_
.
push_back
(
a_grid_desc_ak0_m_ak1
);
b_grid_desc_bk0_n_bk1_container_
.
push_back
(
b_grid_desc_bk0_n_bk1
);
// block-to-e-tile-map
auto
block_2_ctile_map
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n
,
1
/* M01 */
,
1
/* N01 */
);
block_2_ctile_map_container_
.
push_back
(
block_2_ctile_map
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
));
}
}
}
}
void
Print
()
const
{
for
(
std
::
size_t
i
=
0
;
i
<
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
{
std
::
cout
<<
"a_grid_desc_ak0_m_ak1_container_"
<<
a_grid_desc_ak0_m_ak1_container_
[
i
]
<<
std
::
endl
;
std
::
cout
<<
"b_grid_desc_bk0_n_bk1_container_"
<<
b_grid_desc_bk0_n_bk1_container_
[
i
]
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
std
::
cout
<<
"ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<<
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
][
j
]
<<
std
::
endl
;
});
std
::
cout
<<
"e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<<
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
]
<<
std
::
endl
;
}
}
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptor for problem definition
index_t
num_group_
;
std
::
vector
<
DsGridDesc_M_N
>
ds_grid_desc_m_n_container_
;
std
::
vector
<
EGridDesc_M_N
>
e_grid_desc_m_n_container_
;
// tensor descriptor for block-wise copy
std
::
vector
<
AGridDesc_AK0_M_AK1
>
a_grid_desc_ak0_m_ak1_container_
;
std
::
vector
<
BGridDesc_BK0_N_BK1
>
b_grid_desc_bk0_n_bk1_container_
;
std
::
vector
<
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
;
std
::
vector
<
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
;
// block-to-e-tile map
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOp
a_element_op_
;
BElementwiseOp
b_element_op_
;
CDEElementwiseOp
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_c_wis_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
float
ave_time
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
{
const
index_t
grid_size
=
arg
.
block_2_ctile_map_container_
[
i
].
CalculateGridSize
(
arg
.
e_grid_desc_m_n_container_
[
i
])
*
arg
.
num_group_
;
const
auto
GemmK
=
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
].
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
].
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseGemm
,
ADataType
,
BDataType
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_k_wos_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
],
arg
.
b_grid_desc_bk0_n_bk1_container_
[
i
],
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
],
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
],
arg
.
block_2_ctile_map_container_
[
i
],
arg
.
compute_ptr_offset_of_batch_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
GemmK
))
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
get_device_name
()
==
"gfx1100"
||
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
const
index_t
ConvK
=
arg
.
b_g_k_c_xs_lengths_
[
1
];
const
index_t
ConvC
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
// Specialization
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's a 1x1 convolution with stride=1 and no padding
for
(
int
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
if
(
!
(
arg
.
b_g_k_c_xs_lengths_
[
3
+
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
{
return
false
;
}
}
}
// vector load for A matrix from global memory to LDS
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGK
>
)
{
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
ConvK
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// vector load for B matrix from global memory to LDS
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
)
{
if
(
!
(
BBlockTransferSrcVectorDim
==
1
&&
ConvC
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// vector store for Ds
bool
ds_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_C
>
)
{
// vector load D matrix from global memory
if
(
!
(
ConvC
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
ds_valid
=
false
;
}
}
else
{
ds_valid
=
false
;
}
});
if
(
!
ds_valid
)
{
return
false
;
}
// vector store for E
if
constexpr
(
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
NDHWGC
>
)
{
// vector store C matrix into global memory
if
(
!
(
ConvC
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// Gridwise GEMM size
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
],
arg
.
b_grid_desc_bk0_n_bk1_container_
[
i
],
arg
.
ds_grid_desc_m_n_container_
[
i
],
arg
.
e_grid_desc_m_n_container_
[
i
],
arg
.
block_2_ctile_map_container_
[
i
]))
{
return
false
;
}
}
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
// weight
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_lengths
,
// bias
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_strides
,
// bias
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOp
&
a_element_op
,
const
BElementwiseOp
&
b_element_op
,
const
CDEElementwiseOp
&
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
,
ds_g_n_c_wis_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
// weight
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_lengths
,
// bias
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_strides
,
// bias
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOp
&
a_element_op
,
const
BElementwiseOp
&
b_element_op
,
const
CDEElementwiseOp
&
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
,
ds_g_n_c_wis_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
getConvBackwardDataSpecializationString
(
ConvBackwardDataSpecialization
)
<<
", "
<<
K1
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
9f8ab221
...
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
...
...
@@ -24,51 +25,6 @@ namespace device {
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
...
...
@@ -242,7 +198,9 @@ template <index_t NDimSpatial,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
typename
AComputeType
=
ADataType
,
typename
BComputeType
=
AComputeType
>
struct
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
:
public
DeviceGroupedConvBwdDataMultipleD
<
NDimSpatial
,
ALayout
,
// output image
...
...
@@ -255,9 +213,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
EDataType
,
// input image
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
>
CDEElementwiseOp
,
AComputeType
,
BComputeType
>
{
//
FIXME
//
TODO: Extend support for more spatial dimensions.
static_assert
(
NDimSpatial
==
2
||
NDimSpatial
==
3
,
"wrong! only implemented for 2D and 3D now"
);
...
...
@@ -265,7 +225,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
// TODO
make A/
B datatype
different
// TODO
: Add support for different A and
B data
type
s.
using
ABDataType
=
ADataType
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -356,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
A
BDataType
,
// TODO: distinguish A/B datat
ype
ABDataType
,
ABDataType
,
A
ComputeT
ype
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
...
...
@@ -398,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
,
PipelineVersion
::
v1
,
BComputeType
>
;
template
<
typename
Desc_K0_M_K1
>
static
auto
transform_k0_m_k1_to_m_k
(
const
Desc_K0_M_K1
&
desc_k0_m_k1
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
9f8ab221
...
...
@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...
...
@@ -22,32 +23,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
// element-wise op
OutElementwiseOperation
a_element_op_
;
...
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
,
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
K1
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
,
typename
ck
::
enable_if
<
NDimSpatial
==
3
,
bool
>
::
type
=
false
>
struct
DeviceGroupedConvBwdWeight_Wmma_CShuffle
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedConvBwdWeight_Wmma_CShuffle
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
CDataType
=
WeiDataType
;
using
AElementwiseOperation
=
OutElementwiseOperation
;
using
BElementwiseOperation
=
InElementwiseOperation
;
using
CElementwiseOperation
=
WeiElementwiseOperation
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
// 3d
static
constexpr
bool
is_NDHWGK_GKZYXC_NDHWGC
=
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NDHWGK
>
;
static
constexpr
bool
is_GNDHWK_GKZYXC_GNDHWC
=
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNDHWK
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
GemmK1Number
=
Number
<
K1
>
{};
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
GemmK1Number
;
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_out_grid_desc
(
const
index_t
N
,
const
index_t
Do
,
const
index_t
Ho
,
const
index_t
Wo
,
const
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_strides
)
{
const
index_t
WoStride
=
output_strides
[
5
];
const
auto
KStride
=
Number
<
1
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
),
make_tuple
(
WoStride
,
KStride
));
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_in_grid_desc
(
const
index_t
N
,
const
index_t
Di
,
const
index_t
Hi
,
const
index_t
Wi
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_strides
)
{
const
index_t
NStride
=
input_strides
[
1
];
const
index_t
DiStride
=
input_strides
[
3
];
const
index_t
HiStride
=
input_strides
[
4
];
const
index_t
WiStride
=
input_strides
[
5
];
const
auto
CStride
=
input_strides
[
2
];
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Di
*
Hi
*
Wi
,
C
),
make_tuple
(
WiStride
,
CStride
));
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
DiStride
,
HiStride
,
WiStride
,
CStride
));
}
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
constexpr
static
auto
make_wei_grid_desc
(
const
index_t
K
,
const
index_t
Z
,
const
index_t
Y
,
const
index_t
X
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weights_strides
)
{
const
auto
CStride
=
Number
<
1
>
{};
const
auto
KStride
=
weights_strides
[
1
];
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
),
make_tuple
(
KStride
,
CStride
));
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
const
index_t
N
,
const
index_t
K
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
using
namespace
ck
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
Z
=
filter_spatial_lengths
[
0
];
const
index_t
Y
=
filter_spatial_lengths
[
1
];
const
index_t
X
=
filter_spatial_lengths
[
2
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
GemmKTotal
=
N
*
Do
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
Z
*
X
*
Y
;
const
auto
PadGemmM
=
(
MPerBlock
-
GemmM
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadGemmN
=
(
NPerBlock
-
GemmN
%
NPerBlock
)
%
NPerBlock
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
const
auto
out_grid_desc
=
make_out_grid_desc
<
NDim
>
(
N
,
Do
,
Ho
,
Wo
,
K
,
output_strides
);
const
auto
in_grid_desc
=
make_in_grid_desc
<
NDim
>
(
N
,
Di
,
Hi
,
Wi
,
C
,
input_strides
);
const
auto
wei_grid_desc
=
make_wei_grid_desc
<
NDim
>
(
K
,
Z
,
Y
,
X
,
C
,
weights_strides
);
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_grid_desc
);
}
else
{
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// Pad
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc
=
transform_tensor_descriptor
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
PadGemmM
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc
=
transform_tensor_descriptor
(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmN
,
PadGemmN
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemmm_gemmn_pad_grid_desc
=
transform_tensor_descriptor
(
wei_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmM
,
PadGemmM
),
make_right_pad_transform
(
GemmN
,
PadGemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc
,
wei_gemmm_gemmn_pad_grid_desc
);
}
}
template
<
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
const
index_t
dim
=
1
;
const
std
::
array
<
index_t
,
NDimSpatial
>
lengths
{
1
,
1
,
1
};
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
strides
{
1
,
1
,
1
,
1
,
1
,
1
};
const
std
::
array
<
index_t
,
NDimSpatial
>
params
{
1
,
1
,
1
};
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
3
>
(
dim
,
dim
,
dim
,
lengths
,
lengths
,
lengths
,
strides
,
strides
,
strides
,
params
,
params
,
params
,
params
);
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
Tuple
<>
,
CDataType
,
// InMemory Data Descriptor
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
Tuple
<>
,
CGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
NumGemmKPrefetchStage
,
LoopSched
,
PipelineVer
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
Tuple
<>
{}));
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
I1
/* M01 */
,
I1
/* N01 */
));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
index_t
split_k
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_kbatch_k0_m_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
out_element_op
},
b_element_op_
{
in_element_op
},
c_element_op_
{
wei_element_op
},
Conv_G_
{
a_g_n_c_wis_lengths
[
0
]},
Conv_N_
{
a_g_n_c_wis_lengths
[
1
]},
Conv_K_
{
b_g_k_c_xs_lengths
[
1
]},
Conv_C_
{
a_g_n_c_wis_lengths
[
2
]},
input_spatial_lengths_
{},
filter_spatial_lengths_
{},
output_spatial_lengths_
{},
conv_filter_strides_
{
conv_filter_strides
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
{
constexpr
index_t
spatial_offset
=
3
;
std
::
copy
(
begin
(
a_g_n_c_wis_lengths
)
+
spatial_offset
,
end
(
a_g_n_c_wis_lengths
),
begin
(
input_spatial_lengths_
));
std
::
copy
(
begin
(
b_g_k_c_xs_lengths
)
+
spatial_offset
,
end
(
b_g_k_c_xs_lengths
),
begin
(
filter_spatial_lengths_
));
std
::
copy
(
begin
(
e_g_n_k_wos_lengths
)
+
spatial_offset
,
end
(
e_g_n_k_wos_lengths
),
begin
(
output_spatial_lengths_
));
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
a_grid_desc_kbatch_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
I1
/* M01 */
,
I1
/* N01 */
);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
e_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths_
),
end
(
filter_spatial_lengths_
),
index_t
{
1
},
std
::
multiplies
<>
{});
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_kbatch_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
OutElementwiseOperation
a_element_op_
;
InElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
const
index_t
Conv_G_
;
const
index_t
Conv_N_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads_
;
const
index_t
k_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
void
Print
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
Print
(
arg
);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
"setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseGemm
,
ADataType
,
BDataType
,
typename
GridwiseGemm
::
DsGridPointer
,
CDataType
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
>
;
using
EmptyTuple
=
Tuple
<>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
EmptyTuple
{},
// Ds
arg
.
p_c_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
Conv_G_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
{},
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
if
(
has_main_k0_block_loop
)
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
get_device_name
()
==
"gfx1100"
||
get_device_name
()
==
"gfx1101"
||
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
// TODO: Add support for split_k > 1
if
(
arg
.
k_batch_
!=
1
)
{
return
false
;
}
if
constexpr
(
!
(
is_NDHWGK_GKZYXC_NDHWGC
||
is_GNDHWK_GKZYXC_GNDHWC
))
{
return
false
;
}
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's a 1x1 convolution with stride=1 and no padding
for
(
int
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
{
return
false
;
}
}
}
// vector load A/B matrix from global memory
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
BBlockTransferSrcVectorDim
==
1
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
&&
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
index_t
split_k
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
a_g_n_c_wis_lengths
,
// input
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
// weight
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
// output
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
index_t
split_k
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
a_g_n_c_wis_lengths
,
// input
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
// weight
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
// output
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvBwdWeight_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
", "
<<
K1
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMRepeatPerShuffle
<<
", "
<<
CShuffleNRepeatPerShuffle
<<
", "
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
9f8ab221
...
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -21,34 +22,9 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
...
...
@@ -64,8 +40,8 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_xdlops_bwd_weight
(
const
FloatA
B
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
...
...
@@ -91,7 +67,7 @@ __global__ void
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
)));
__shared__
FloatA
B
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatA
B
)];
__shared__
FloatA
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatA
)];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
...
...
@@ -163,7 +139,9 @@ template <ck::index_t NDimSpatial,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
,
typename
ComputeTypeA
=
InDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGroupedConvBwdWeight_Xdl_CShuffle
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
InLayout
,
...
...
@@ -174,7 +152,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
OutElementwiseOperation
,
ComputeTypeA
,
ComputeTypeB
>
{
using
DeviceOp
=
DeviceGroupedConvBwdWeight_Xdl_CShuffle
;
...
...
@@ -1045,7 +1025,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
AtomicAdd
,
...
...
@@ -1090,7 +1071,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
true
,
true
>
;
true
,
1
,
PipelineVersion
::
v1
,
ComputeTypeA
,
ComputeTypeB
>
;
// Argument
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
...
...
@@ -1212,13 +1197,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
N01_
;
In
ElementwiseOperation
a_element_op_
;
Out
ElementwiseOperation
b_element_op_
;
Out
ElementwiseOperation
a_element_op_
;
In
ElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
...
...
@@ -1281,7 +1266,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const
auto
kernel
=
kernel_batched_gemm_xdlops_bwd_weight
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
BDataType
,
CDataType
,
OutElementwiseOperation
,
InElementwiseOperation
,
...
...
@@ -1290,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -1337,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
NDimSpatial
==
1
)
{
if
constexpr
(
!
is_GNWK_GKXC_GNWC
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
9f8ab221
...
...
@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
...
...
@@ -29,51 +30,6 @@ namespace device {
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
9f8ab221
...
...
@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
...
...
@@ -27,55 +28,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
}
// namespace
//
// @brief Device Convolution operation.
//
...
...
@@ -519,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
9f8ab221
...
...
@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
...
...
@@ -29,51 +30,6 @@ namespace device {
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
...
...
@@ -255,7 +211,8 @@ template <index_t NDimSpatial,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
typename
ComputeDataType
=
ADataType
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
:
public
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
ALayout
,
...
...
@@ -268,7 +225,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
CDEElementwiseOperation
,
ComputeDataType
>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
;
...
...
@@ -367,8 +325,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment