Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5aa3c344
Unverified
Commit
5aa3c344
authored
Oct 05, 2022
by
rocking5566
Committed by
GitHub
Oct 05, 2022
Browse files
Merge branch 'develop' into gemm_layernorm_welford
parents
7fefc966
9d8f834a
Changes
129
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4177 additions
and
267 deletions
+4177
-267
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp
.../ck/tensor_operation/gpu/device/device_layernorm_impl.hpp
+88
-107
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+37
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+1015
-0
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
.../tensor_operation/gpu/device/impl/device_permute_impl.hpp
+282
-0
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
+159
-0
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+24
-0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+55
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+36
-0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+44
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
...tched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
+1268
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+93
-29
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp
+4
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+17
-25
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp
..._operation/gpu/grid/gridwise_layernorm_naive_variance.hpp
+55
-53
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
...peration/gpu/grid/gridwise_layernorm_welford_variance.hpp
+52
-50
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
+339
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+1
-0
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+583
-0
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
...eration/operator_transform/transform_conv_fwd_to_gemm.hpp
+24
-0
include/ck/utility/ignore.hpp
include/ck/utility/ignore.hpp
+1
-3
No files found.
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp
View file @
5aa3c344
...
...
@@ -23,11 +23,10 @@ template <typename GridwiseReduction,
typename
YDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_K
>
typename
GridDesc_M_K
>
__global__
void
kernel_layernorm
(
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_K
gamma_grid_desc_k
,
const
GridDesc_K
beta_grid_desc_k
,
const
GridDesc_
M_
K
gamma_grid_desc_
m_
k
,
const
GridDesc_
M_
K
beta_grid_desc_
m_
k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
...
...
@@ -38,8 +37,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const
AccElementwiseOperation
acc_elementwise_op
)
{
GridwiseReduction
::
Run
(
x_grid_desc_m_k
,
gamma_grid_desc_k
,
beta_grid_desc_k
,
gamma_grid_desc_
m_
k
,
beta_grid_desc_
m_
k
,
y_grid_desc_m_k
,
num_k_block_tile_iteration
,
epsilon
,
...
...
@@ -71,7 +70,9 @@ template <typename XDataType,
index_t
KThreadSliceSize
,
index_t
XYSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
>
struct
DeviceLayernormImpl
:
public
DeviceLayernorm
<
XDataType
,
...
...
@@ -84,11 +85,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
NumReduceDim
>
{
static_assert
(
(
KThreadSliceSize
%
GammaSrcVectorSize
==
0
),
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
%
GammaSrcVectorSize
==
0
)
||
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
%
GammaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"
);
static_assert
(
(
KThreadSliceSize
%
BetaSrcVectorSize
==
0
),
((
BetaSrcVectorDim
==
0
&&
MThreadSliceSize
%
BetaSrcVectorSize
==
0
)
||
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
...
...
@@ -162,38 +165,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
return
(
in_grid_desc_m_k_padded
);
};
static
auto
MakeAffine1dDescriptor
(
const
std
::
vector
<
index_t
>&
Lengths
,
const
std
::
vector
<
index_t
>&
Strides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
const
auto
tupleLengths
=
make_tuple_from_array
(
Lengths
,
Number
<
NumReduceDim
>
{});
const
auto
tupleStrides
=
make_tuple_from_array
(
Strides
,
Number
<
NumReduceDim
>
{});
auto
desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
auto
grid_desc_k
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumReduceDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
reduceTotalLength
=
grid_desc_k
.
GetLength
(
Number
<
0
>
{});
const
int
reduceSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
Pad_K
=
reduceSizePerBlock
*
blkGroupSize
-
reduceTotalLength
;
auto
grid_desc_k_padded
=
transform_tensor_descriptor
(
grid_desc_k
,
make_tuple
(
make_right_pad_transform
(
reduceTotalLength
,
Pad_K
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
grid_desc_k_padded
);
};
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridDesc_K
=
decltype
(
MakeAffine1dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseReduceLayernormGeneric
=
GridwiseLayernormWelfordVariance_mk_to_mk
<
XDataType
,
...
...
@@ -203,7 +175,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -211,12 +182,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
false
>
;
using
GridwiseReduceLayernormSweepOnce
=
GridwiseLayernormWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
...
...
@@ -225,7 +197,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -233,7 +204,9 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
...
...
@@ -258,13 +231,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_y_
(
p_y
),
gammaStrides_
(
gammaStrides
),
betaStrides_
(
betaStrides
),
acc_elementwise_op_
(
acc_elementwise_op
)
{
Lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
Lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
long_index_t
invariant_total_length
;
long_index_t
reduce_total_length
;
...
...
@@ -278,12 +251,17 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
gridSize_
=
math
::
integer_least_multiple
(
invariant_total_length
,
M_BlockTileSize
)
/
M_BlockTileSize
*
blkGroupSize_
;
reduceLengths_
.
resize
(
NumReduceDim
);
for
(
int
i
=
0
;
i
<
NumReduceDim
;
++
i
)
{
reduceLengths_
[
i
]
=
lengths
[
reduceDims
[
i
]];
}
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
gamma_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
gammaStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
beta_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
isSweeponce_
=
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
}
AccDataType
epsilon_
;
...
...
@@ -295,7 +273,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
std
::
vector
<
index_t
>
Lengths_
;
std
::
vector
<
index_t
>
xStrides_
;
std
::
vector
<
index_t
>
reduceLengths_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
...
...
@@ -305,46 +282,35 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
int
blkGroupSize_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
GridDesc_M_K
x_grid_desc_m_k_
;
GridDesc_M_K
gamma_grid_desc_m_k_
;
GridDesc_M_K
beta_grid_desc_m_k_
;
GridDesc_M_K
y_grid_desc_m_k_
;
bool
isSweeponce_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
x_grid_desc_m_k
=
MakeSrc2dDescriptor
(
arg
.
Lengths_
,
arg
.
xStrides_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
);
const
auto
gamma_grid_desc_k
=
MakeAffine1dDescriptor
(
arg
.
reduceLengths_
,
arg
.
gammaStrides_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
);
const
auto
beta_grid_desc_k
=
MakeAffine1dDescriptor
(
arg
.
reduceLengths_
,
arg
.
betaStrides_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
);
const
auto
y_grid_desc_m_k
=
MakeSrc2dDescriptor
(
arg
.
Lengths_
,
arg
.
yStrides_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
);
bool
sweep_once
=
x_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
const
auto
kernel_main
=
sweep_once
?
kernel_layernorm
<
GridwiseReduceLayernormSweepOnce
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
>
:
kernel_layernorm
<
GridwiseReduceLayernormGeneric
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
>
;
const
auto
kernel_main
=
arg
.
isSweeponce_
?
kernel_layernorm
<
GridwiseReduceLayernormSweepOnce
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
>
:
kernel_layernorm
<
GridwiseReduceLayernormGeneric
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -352,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
x_grid_desc_m_k
,
gamma_grid_desc_
k
,
beta_grid_desc_
k
,
y_grid_desc_m_k
,
arg
.
x_grid_desc_m_k
_
,
arg
.
gamma_grid_desc_
m_k_
,
arg
.
beta_grid_desc_
m_k_
,
arg
.
y_grid_desc_m_k
_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
p_x_
,
...
...
@@ -409,26 +375,41 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
return
false
;
}
if
(
p_arg_
->
gammaStrides_
.
size
()
!=
NumReduceDim
||
p_arg_
->
betaStrides_
.
size
()
!=
NumReduceDim
)
return
false
;
// if fastest dim is not reduced
if
constexpr
(
GammaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
gammaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
);
auto
IsScalarPerVectorValid
=
[](
bool
isLastDimensionCoalesced
,
int
scalarPerVector
)
{
bool
ret
=
true
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
(
false
);
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
gammaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
);
if
(
!
isLastDimensionCoalesced
)
ret
=
scalarPerVector
==
1
;
else
ret
=
KThreadSliceSize
%
scalarPerVector
==
0
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
(
false
);
}
return
ret
;
};
// if fastest dim is not reduced
if
constexpr
(
BetaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
);
if
(
!
IsScalarPerVectorValid
(
p_arg_
->
gammaStrides_
.
back
()
==
1
,
GammaSrcVectorSize
))
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
return
(
false
);
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
betaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
);
if
(
!
IsScalarPerVectorValid
(
p_arg_
->
betaStrides_
.
back
()
==
1
,
BetaSrcVectorSize
))
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
BetaSrcVectorSize
!=
0
)
return
(
false
);
}
return
true
;
};
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
>
struct
DevicePermute
:
BaseOperator
{
using
Lengths
=
std
::
array
<
index_t
,
NumDim
>
;
using
Strides
=
Lengths
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
&
in_lengths
,
const
Strides
&
in_strides
,
const
Lengths
&
out_lengths
,
const
Strides
&
out_strides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
ABDataType
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
const
auto
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
DsPointer
p_ds_grid_grp
;
static
constexpr
index_t
NumDTensor
=
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_batch_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
block_2_ctile_map
;
#endif
}
}
// namespace
// Conv backward data multiple D:
// input : output image A: [G, N, K, Ho, Wo]
// input : weight B: [G, K, C, Y, X],
// input : D0, D1, ... : [G, N, K, Ho, Wo]
// output : input image E: [G, N, C, Hi, Wi]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
index_t
NDimSpatial
,
typename
ALayout
,
// output image
typename
BLayout
,
// weight
typename
DsLayout
,
// bias
typename
ELayout
,
// input image
typename
ADataType
,
// output image
typename
BDataType
,
// weight
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
// bias
typename
EDataType
,
// input image
typename
AElementwiseOp
,
// output image
typename
BElementwiseOp
,
// weight
typename
CDEElementwiseOp
,
// C, bias, and input image
ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization
,
bool
DoPadGemmM
,
bool
DoPadGemmN
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
:
public
DeviceGroupedConvBwdDataMultipleD
<
NDimSpatial
,
ALayout
,
// output image
BLayout
,
// weight
DsLayout
,
// bias
ELayout
,
// input image
ADataType
,
// output image
BDataType
,
// weight
DsDataType
,
// bias
EDataType
,
// input image
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
>
{
// FIXME
static_assert
(
NDimSpatial
==
2
,
"wrong! only implemented for 2D now"
);
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
// TODO make A/B datatype different
using
ABDataType
=
ADataType
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
transform_conv_to_gemm
=
TransformConvBwdDataToGemm_v1
<
NDimSpatial
,
ConvBackwardDataSpecialization
,
AK1
,
BK1
,
MPerBlock
,
NPerBlock
,
DoPadGemmM
,
DoPadGemmN
>
{};
static
auto
GetDummyABDsEGridDescriptor
()
{
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
dummy_tensor_lengths
=
{
1
};
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>
dummy_tensor_strides
=
{
1
};
const
std
::
array
<
index_t
,
NDimSpatial
>
dummy_spatial_lengths
=
{
1
};
const
auto
a_grid_desc_ak0_m_ak1
=
transform_conv_to_gemm
.
template
MakeADescriptor_AK0_M_AK1
<
ALayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
const
auto
b_grid_desc_bk0_n_bk1
=
transform_conv_to_gemm
.
template
MakeBDescriptor_BK0_N_BK1
<
BLayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
const
auto
ds_grid_desc_m_n
=
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
DLayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
},
Number
<
NumDTensor
>
{});
const
auto
e_grid_desc_m_n
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
ELayout
>(
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_tensor_lengths
,
dummy_tensor_strides
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
,
dummy_spatial_lengths
);
return
make_tuple
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
);
}
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
template
<
typename
Desc_K0_M_K1
>
static
auto
transform_k0_m_k1_to_m_k
(
const
Desc_K0_M_K1
&
desc_k0_m_k1
)
{
const
auto
grid_desc_m_k
=
transform_tensor_descriptor
(
desc_k0_m_k1
,
make_tuple
(
make_pass_through_transform
(
desc_k0_m_k1
.
GetLength
(
I1
)),
make_merge_transform
(
make_tuple
(
desc_k0_m_k1
.
GetLength
(
I0
),
desc_k0_m_k1
.
GetLength
(
I2
)))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
grid_desc_m_k
;
}
// desc
using
ABDsEGridDesc
=
decltype
(
GetDummyABDsEGridDescriptor
());
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
tuple_element_t
<
0
,
ABDsEGridDesc
>>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
tuple_element_t
<
1
,
ABDsEGridDesc
>>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
tuple_element_t
<
2
,
ABDsEGridDesc
>>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
tuple_element_t
<
3
,
ABDsEGridDesc
>>
;
using
AGridDesc_M_K
=
decltype
(
transform_k0_m_k1_to_m_k
(
AGridDesc_AK0_M_AK1
{}));
using
BGridDesc_N_K
=
decltype
(
transform_k0_m_k1_to_m_k
(
BGridDesc_BK0_N_BK1
{}));
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}));
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}));
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOp
&
a_element_op
,
const
BElementwiseOp
&
b_element_op
,
const
CDEElementwiseOp
&
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_k_wos_lengths
[
0
]},
num_gemm_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_k_wos_lengths_
{
a_g_n_k_wos_lengths
},
a_g_n_k_wos_strides_
{
a_g_n_k_wos_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_c_wis_lengths_
{
ds_g_n_c_wis_lengths
},
ds_g_n_c_wis_strides_
{
ds_g_n_c_wis_strides
},
e_g_n_c_wis_lengths_
{
e_g_n_c_wis_lengths
},
e_g_n_c_wis_strides_
{
e_g_n_c_wis_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// populate Ds pointer
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
});
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_c_wis_strides
[
0
];
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_c_wis_strides
[
i
][
0
];
});
// problem definition
const
index_t
Y
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
4
];
const
index_t
ConvStrideH
=
conv_filter_strides_
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides_
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations_
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
1
];
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
// number of GEMM
num_gemm_
=
YTilde
*
XTilde
;
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
const
auto
a_grid_desc_ak0_m_ak1
=
transform_conv_to_gemm
.
template
MakeADescriptor_AK0_M_AK1
<
ALayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
const
auto
b_grid_desc_bk0_n_bk1
=
transform_conv_to_gemm
.
template
MakeBDescriptor_BK0_N_BK1
<
BLayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
DsGridDesc_M_N
ds_grid_desc_m_n
;
// populate Ds desc
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
ds_grid_desc_m_n
(
i
)
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
DLayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
[
i
],
ds_g_n_c_wis_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
});
const
auto
e_grid_desc_m_n
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
ELayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
// desc for problem definition
const
auto
a_grid_desc_m_k
=
transform_k0_m_k1_to_m_k
(
a_grid_desc_ak0_m_ak1
);
const
auto
b_grid_desc_n_k
=
transform_k0_m_k1_to_m_k
(
b_grid_desc_bk0_n_bk1
);
a_grid_desc_m_k_container_
.
push_back
(
a_grid_desc_m_k
);
b_grid_desc_n_k_container_
.
push_back
(
b_grid_desc_n_k
);
ds_grid_desc_m_n_container_
.
push_back
(
ds_grid_desc_m_n
);
e_grid_desc_m_n_container_
.
push_back
(
e_grid_desc_m_n
);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_
.
push_back
(
a_grid_desc_ak0_m_ak1
);
b_grid_desc_bk0_n_bk1_container_
.
push_back
(
b_grid_desc_bk0_n_bk1
);
// block-to-e-tile-map
auto
block_2_etile_map
=
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
);
block_2_etile_map_container_
.
push_back
(
block_2_etile_map
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
block_2_etile_map
))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
));
}
}
}
}
void
Print
()
const
{
for
(
index_t
i
=
0
;
i
<
num_gemm_
;
i
++
)
{
std
::
cout
<<
"a_grid_desc_ak0_m_ak1_container_"
<<
a_grid_desc_ak0_m_ak1_container_
[
i
]
<<
std
::
endl
;
std
::
cout
<<
"b_grid_desc_bk0_n_bk1_container_"
<<
b_grid_desc_bk0_n_bk1_container_
[
i
]
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
std
::
cout
<<
"ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<<
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
][
j
]
<<
std
::
endl
;
});
std
::
cout
<<
"e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<<
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
]
<<
std
::
endl
;
}
}
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptor for problem definition
index_t
num_group_
;
index_t
num_gemm_
;
std
::
vector
<
AGridDesc_M_K
>
a_grid_desc_m_k_container_
;
std
::
vector
<
BGridDesc_N_K
>
b_grid_desc_n_k_container_
;
std
::
vector
<
DsGridDesc_M_N
>
ds_grid_desc_m_n_container_
;
std
::
vector
<
EGridDesc_M_N
>
e_grid_desc_m_n_container_
;
// tensor descriptor for block-wise copy
std
::
vector
<
AGridDesc_AK0_M_AK1
>
a_grid_desc_ak0_m_ak1_container_
;
std
::
vector
<
BGridDesc_BK0_N_BK1
>
b_grid_desc_bk0_n_bk1_container_
;
std
::
vector
<
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
;
std
::
vector
<
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
;
// block-to-e-tile map
std
::
vector
<
Block2ETileMap
>
block_2_etile_map_container_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOp
a_element_op_
;
BElementwiseOp
b_element_op_
;
CDEElementwiseOp
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_c_wis_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
float
ave_time
=
0
;
for
(
index_t
i
=
0
;
i
<
arg
.
num_gemm_
;
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_container_
[
i
],
arg
.
b_grid_desc_n_k_container_
[
i
],
arg
.
ds_grid_desc_m_n_container_
[
i
],
arg
.
e_grid_desc_m_n_container_
[
i
],
arg
.
block_2_etile_map_container_
[
i
]))
{
throw
std
::
runtime_error
(
"wrong! device_op has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_container_
[
i
].
CalculateGridSize
(
arg
.
e_grid_desc_m_n_container_
[
i
])
*
arg
.
num_group_
;
const
auto
GemmK
=
arg
.
a_grid_desc_m_k_container_
[
i
].
GetLength
(
I1
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_k_wos_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
],
arg
.
b_grid_desc_bk0_n_bk1_container_
[
i
],
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
],
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
],
arg
.
block_2_etile_map_container_
[
i
],
arg
.
compute_ptr_offset_of_batch_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
GemmK
))
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
index_t
ConvK
=
arg
.
b_g_k_c_xs_lengths_
[
1
];
const
index_t
ConvC
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
// Specifialization
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 pad = 0 conv
for
(
int
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
if
(
!
(
arg
.
b_g_k_c_xs_lengths_
[
3
+
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
{
return
false
;
}
}
}
// vector load for A matrix from global memory to LDS
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
)
{
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
ConvK
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// vector load for B matrix from global memory to LDS
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
)
{
if
(
!
(
BBlockTransferSrcVectorDim
==
1
&&
ConvC
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// vector store for Ds
bool
ds_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_C
>
)
{
// vector load D matrix from global memory
if
(
!
(
ConvC
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
ds_valid
=
false
;
}
}
else
{
ds_valid
=
false
;
}
});
if
(
!
ds_valid
)
{
return
false
;
}
// vector store for E
if
constexpr
(
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNHWC
>
)
{
// vector store C matrix into global memory
if
(
!
(
ConvC
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// Gridwise GEMM size
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_container_
[
i
],
arg
.
b_grid_desc_n_k_container_
[
i
],
arg
.
ds_grid_desc_m_n_container_
[
i
],
arg
.
e_grid_desc_m_n_container_
[
i
],
arg
.
block_2_etile_map_container_
[
i
]))
{
return
false
;
}
}
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
// weight
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_lengths
,
// bias
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_strides
,
// bias
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOp
&
a_element_op
,
const
BElementwiseOp
&
b_element_op
,
const
CDEElementwiseOp
&
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
,
ds_g_n_c_wis_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
// weight
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_lengths
,
// bias
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_c_wis_strides
,
// bias
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOp
&
a_element_op
,
const
BElementwiseOp
&
b_element_op
,
const
CDEElementwiseOp
&
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
,
ds_g_n_c_wis_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getConvBackwardDataSpecializationString
(
ConvBackwardDataSpecialization
)
<<
">"
;
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include <utility>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// ^^^^^^^^^^^
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
// ^^^^^^^^^^^
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
index_t
BlockSize
,
index_t
NPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
InBlockLdsExtraW
,
typename
InBlockTransferThreadClusterLengths
,
typename
InBlockTransferThreadClusterArrangeOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
struct
DevicePermuteImpl
:
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
{
using
BaseType
=
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
using
typename
BaseType
::
Lengths
;
using
typename
BaseType
::
Strides
;
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
static_assert
((
NumDim
-
2
)
<=
DstVectorDim
&&
DstVectorDim
<
NumDim
);
static_assert
(
SrcVectorDim
!=
DstVectorDim
);
template
<
index_t
N
=
NumDim
>
static
auto
ConvertArrayToTuple
(
const
std
::
array
<
index_t
,
NumDim
>&
array
)
{
static_assert
(
1
<=
N
&&
N
<=
NumDim
);
return
generate_tuple
([
&
](
auto
I
)
{
return
array
[
I
];
},
Number
<
N
>
{});
}
static
auto
MakeDescriptor_N_H_W
(
const
Lengths
&
lengths
,
const
Strides
&
stride
)
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]]
const
auto
desc
=
make_naive_tensor_descriptor
(
ConvertArrayToTuple
(
lengths
),
ConvertArrayToTuple
(
stride
));
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
// d[NumDim-1]]
// => [N, H, W]
const
index_t
H
=
*
std
::
next
(
rbegin
(
lengths
));
const
index_t
W
=
*
rbegin
(
lengths
);
const
auto
desc_n_h_w
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
ConvertArrayToTuple
<
NumDim
-
2
>
(
lengths
)),
make_pass_through_transform
(
H
),
make_pass_through_transform
(
W
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
NumDim
-
2
>
{}),
Sequence
<
NumDim
-
2
>
{},
Sequence
<
NumDim
-
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
PadTensorDescriptor
(
desc_n_h_w
,
make_tuple
(
NPerBlock
,
HPerBlock
,
WPerBlock
),
Sequence
<
true
,
true
,
true
>
{});
}
using
InGridDesc
=
decltype
(
MakeDescriptor_N_H_W
({
1
,
1
},
{
1
,
1
}));
using
OutGridDesc
=
InGridDesc
;
using
GridwisePermute
=
GridwisePermute
<
InGridDesc
,
OutGridDesc
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
BlockSize
,
NPerBlock
,
HPerBlock
,
WPerBlock
,
InBlockLdsExtraW
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
SrcVectorDim
-
(
NumDim
-
3
),
// calculate new SrcVectorDim for the merged descriptor
DstVectorDim
-
(
NumDim
-
3
),
// calculate new DstVectorDim for the merged descriptor
SrcScalarPerVector
,
DstScalarPerVector
>
;
using
Block2TileMap
=
typename
GridwisePermute
::
DefaultBlock2TileMap
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
Lengths
&
in_lengths
,
const
Strides
&
in_strides
,
const
Lengths
&
out_lengths
,
const
Strides
&
out_strides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
:
in_dev_buffer_
(
static_cast
<
const
InDataType
*>
(
in_dev_buffer
)),
out_dev_buffer_
(
static_cast
<
OutDataType
*>
(
out_dev_buffer
)),
in_grid_desc_
(
MakeDescriptor_N_H_W
(
in_lengths
,
in_strides
)),
out_grid_desc_
(
MakeDescriptor_N_H_W
(
out_lengths
,
out_strides
)),
in_lengths_
(
in_lengths
),
in_strides_
(
in_strides
),
out_lengths_
(
out_lengths
),
out_strides_
(
out_strides
),
elementwise_op_
(
elementwise_op
),
block_2_tile_map_
(
GridwisePermute
::
MakeDefaultBlock2TileMap
(
in_grid_desc_
))
{
}
const
InDataType
*
in_dev_buffer_
;
OutDataType
*
out_dev_buffer_
;
InGridDesc
in_grid_desc_
;
OutGridDesc
out_grid_desc_
;
Lengths
in_lengths_
;
Strides
in_strides_
;
Lengths
out_lengths_
;
Strides
out_strides_
;
ElementwiseOperation
elementwise_op_
;
Block2TileMap
block_2_tile_map_
;
};
struct
Invoker
:
BaseInvoker
{
static
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
index_t
grid_size
=
arg
.
block_2_tile_map_
.
CalculateGridSize
(
arg
.
in_grid_desc_
);
const
auto
kernel
=
kernel_nd_permute
<
GridwisePermute
,
InGridDesc
,
OutGridDesc
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
Block2TileMap
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
,
arg
.
in_dev_buffer_
,
arg
.
out_dev_buffer_
,
arg
.
elementwise_op_
,
arg
.
block_2_tile_map_
);
return
elapsed_time
;
}
float
Run
(
const
BaseArgument
*
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
final
{
const
auto
*
const
argument
=
dynamic_cast
<
const
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
NAN
;
}
return
Run
(
*
argument
,
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
constexpr
auto
GetPaddedLength
=
[](
index_t
length
,
index_t
tile_length
)
{
return
math
::
integer_divide_ceil
(
length
,
tile_length
)
*
tile_length
;
};
constexpr
auto
IsScalarPerVectorValid
=
[](
index_t
length
,
index_t
stride
,
index_t
scalar_per_vector
)
{
if
(
stride
==
1
&&
length
%
scalar_per_vector
==
0
)
{
return
true
;
}
else
if
(
stride
!=
1
&&
scalar_per_vector
==
1
)
{
return
true
;
}
return
false
;
};
return
IsScalarPerVectorValid
(
arg
.
in_lengths_
[
SrcVectorDim
],
arg
.
in_strides_
[
SrcVectorDim
],
SrcScalarPerVector
)
&&
IsScalarPerVectorValid
(
GetPaddedLength
(
arg
.
in_lengths_
[
SrcVectorDim
],
(
SrcVectorDim
==
NumDim
-
2
?
HPerBlock
:
WPerBlock
)),
arg
.
in_strides_
[
SrcVectorDim
],
SrcScalarPerVector
)
&&
IsScalarPerVectorValid
(
arg
.
out_lengths_
[
DstVectorDim
],
arg
.
out_strides_
[
DstVectorDim
],
DstScalarPerVector
)
&&
IsScalarPerVectorValid
(
GetPaddedLength
(
arg
.
out_lengths_
[
DstVectorDim
],
(
DstVectorDim
==
NumDim
-
2
?
HPerBlock
:
WPerBlock
)),
arg
.
in_strides_
[
DstVectorDim
],
DstScalarPerVector
)
&&
GridwisePermute
::
CheckValidity
(
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
);
};
// override methods inherited from 'BaseOperator'
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
const
argument
=
dynamic_cast
<
const
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
IsSupportedArgument
(
*
argument
);
}
// override methods inherited from 'DevicePermute'
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
&
in_lengths
,
const
Strides
&
in_strides
,
const
Lengths
&
out_lengths
,
const
Strides
&
out_strides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
override
final
{
return
std
::
make_unique
<
Argument
>
(
in_lengths
,
in_strides
,
out_lengths
,
out_strides
,
in_dev_buffer
,
out_dev_buffer
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
final
{
return
std
::
make_unique
<
Invoker
>
();
};
// other constructor methods
template
<
typename
...
Args
>
static
std
::
enable_if_t
<
std
::
is_constructible_v
<
Argument
,
Args
...
>
,
Argument
>
MakeArgument
(
Args
&&
...
args
)
noexcept
(
std
::
is_nothrow_constructible_v
<
Argument
,
Args
...
>
)
{
return
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
static
std
::
enable_if_t
<
std
::
is_default_constructible_v
<
Invoker
>
,
Invoker
>
MakeInvoker
()
noexcept
(
std
::
is_nothrow_default_constructible_v
<
Invoker
>
)
{
return
Invoker
{};
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
View file @
5aa3c344
...
...
@@ -218,6 +218,165 @@ struct GemmPadder_v2
KPerTileType
KPerTile_
;
};
// M/N/KPerTileType could be index_t or Number<>
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
,
typename
MPerTileType
,
typename
NPerTileType
,
typename
KPerTileType
>
struct
MatrixPadder_v2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
template
<
typename
ADesc_MRaw_KRaw
>
__host__
__device__
constexpr
auto
PadADescriptor_M_K
(
const
ADesc_MRaw_KRaw
&
a_desc_mraw_kraw
)
const
{
const
auto
MRaw
=
a_desc_mraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
a_desc_mraw_kraw
.
GetLength
(
I1
);
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerTile_
)
*
MPerTile_
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerTile_
)
*
KPerTile_
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
PadM
&&
PadK
)
{
// pad both M and K
return
transform_tensor_descriptor
(
a_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
PadM
&&
(
!
PadK
))
{
// pad M, but not K
return
transform_tensor_descriptor
(
a_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
KRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
((
!
PadM
)
&&
PadK
)
{
// pad K, but not M
return
transform_tensor_descriptor
(
a_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or K
return
a_desc_mraw_kraw
;
}
}
template
<
typename
BDesc_NRaw_KRaw
>
__host__
__device__
constexpr
auto
PadBDescriptor_N_K
(
const
BDesc_NRaw_KRaw
&
b_desc_nraw_kraw
)
const
{
const
auto
NRaw
=
b_desc_nraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
b_desc_nraw_kraw
.
GetLength
(
I1
);
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerTile_
)
*
NPerTile_
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerTile_
)
*
KPerTile_
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
PadN
&&
PadK
)
{
// pad both N and K
return
transform_tensor_descriptor
(
b_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
PadN
&&
(
!
PadK
))
{
// pad N, but not K
return
transform_tensor_descriptor
(
b_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_pass_through_transform
(
KRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
((
!
PadN
)
&&
PadK
)
{
// pad K, but not N
return
transform_tensor_descriptor
(
b_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad N or K
return
b_desc_nraw_kraw
;
}
}
template
<
typename
CDesc_MRaw_NRaw
>
__host__
__device__
constexpr
auto
PadCDescriptor_M_N
(
const
CDesc_MRaw_NRaw
&
c_desc_mraw_nraw
)
const
{
const
auto
MRaw
=
c_desc_mraw_nraw
.
GetLength
(
I0
);
const
auto
NRaw
=
c_desc_mraw_nraw
.
GetLength
(
I1
);
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerTile_
)
*
MPerTile_
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerTile_
)
*
NPerTile_
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
PadM
&&
PadN
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
PadM
&&
(
!
PadN
))
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
((
!
PadM
)
&&
PadN
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_desc_mraw_nraw
;
}
}
MPerTileType
MPerTile_
;
NPerTileType
NPerTile_
;
KPerTileType
KPerTile_
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
5aa3c344
...
...
@@ -92,6 +92,12 @@ struct GNDHWC : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"GNDHWC"
;
};
// for input bias
struct
GC
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"GC"
;
};
// input tensor
// packed NWGC/NHWGC/NDHWGC
struct
NWGC
:
public
BaseTensorLayout
...
...
@@ -126,6 +132,12 @@ struct G_NDHW_C : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"G_NDHW_C"
;
};
// for input bias
struct
G_C
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"G_C"
;
};
// weight tensor
// packed KCX/KCYX/KCZYX
struct
KCX
:
public
BaseTensorLayout
...
...
@@ -296,6 +308,12 @@ struct GNDHWK : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"GNDHWK"
;
};
// for output bias
struct
GK
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"GK"
;
};
// output tensor
// packed NWGK/NHWGK/NDHWGK
struct
NWGK
:
public
BaseTensorLayout
...
...
@@ -330,6 +348,12 @@ struct G_NDHW_K : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"G_NDHW_K"
;
};
// for output bias
struct
G_K
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"G_K"
;
};
// K-reduced output tensor (packed)
struct
GNW
:
public
BaseTensorLayout
{
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
5aa3c344
...
...
@@ -28,6 +28,13 @@ struct Add
y
=
x0
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
y
=
x0
+
type_convert
<
half_t
>
(
x1
);
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
...
...
@@ -172,6 +179,14 @@ struct AddRelu
const
float
a
=
x0
+
x1
;
y
=
a
>
type_convert
<
half_t
>
(
0.0
f
)
?
a
:
type_convert
<
half_t
>
(
0.0
f
);
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
half_t
>
(
float
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
const
float
a
=
x0
+
type_convert
<
float
>
(
x1
);
y
=
a
>
0.0
f
?
a
:
0.0
f
;
};
};
struct
AddHardswish
...
...
@@ -210,6 +225,46 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct
AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
static
constexpr
float
GetFastGeLU
(
float
x
)
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
return
x
*
cdf
;
}
template
<
typename
T
>
static
inline
constexpr
bool
is_valid_param_type_v
=
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
half_t
>
||
std
::
is_same_v
<
T
,
bhalf_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
;
template
<
typename
E
,
typename
C
,
typename
D
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
{
static_assert
(
is_valid_param_type_v
<
E
>
&&
is_valid_param_type_v
<
C
>
&&
is_valid_param_type_v
<
D
>
);
const
float
y
=
GetFastGeLU
(
type_convert
<
float
>
(
c
)
+
type_convert
<
float
>
(
d
));
e
=
type_convert
<
E
>
(
y
);
}
template
<
typename
D
>
__host__
__device__
constexpr
void
operator
()(
float
&
e
,
const
float
&
c
,
const
D
&
d
)
const
{
static_assert
(
is_valid_param_type_v
<
D
>
);
e
=
GetFastGeLU
(
c
+
type_convert
<
float
>
(
d
));
}
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
5aa3c344
...
...
@@ -211,6 +211,42 @@ struct FastGelu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct
Gelu
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
0.5
f
*
x
*
(
1.
f
+
erf
(
float
(
0.70710678118
f
*
x
)));
}
template
<
>
__host__
__device__
void
operator
()
<
ck
::
half_t
,
ck
::
half_t
>
(
ck
::
half_t
&
y
,
const
ck
::
half_t
&
x
)
const
{
y
=
ck
::
half_t
(
0.5
)
*
x
*
(
ck
::
half_t
(
1
)
+
ck
::
half_t
(
erf
(
float
(
0.70710678118
f
*
x
))));
}
};
struct
Sigmoid
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
1
/
(
ck
::
type_convert
<
T
>
(
1
)
+
exp
(
-
x
));
};
int32_t
divider_
=
1
;
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
5aa3c344
...
...
@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
return
is_valid
;
}
// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the
// workgroups assigned to a given gemm problem have top index offsetted to range [0,
// grid_size_per_gemm]
template
<
typename
UnderlyingBlockToCTileMap
>
struct
OffsettedBlockToCTileMap
{
using
underlying_type
=
UnderlyingBlockToCTileMap
;
OffsettedBlockToCTileMap
(
UnderlyingBlockToCTileMap
block_to_ctile_map
,
index_t
block_start
)
{
block_to_ctile_map_
=
block_to_ctile_map
;
block_start_
=
block_start
;
}
template
<
typename
TopIdx
>
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
const
{
return
block_to_ctile_map_
.
CalculateBottomIndex
(
make_multi_index
(
idx_top
[
Number
<
0
>
{}]
-
block_start_
));
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
bool
ValidCTileIndex
(
const
CTileIdx
&
c_tile_idx
,
const
CTileDim
&
c_tile_dim
)
const
{
return
block_to_ctile_map_
.
ValidCTileIndex
(
c_tile_idx
,
c_tile_dim
);
}
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
return
block_to_ctile_map_
.
CheckValidity
(
c_grid_desc_m_n
);
}
template
<
typename
CGridDesc_M_N
>
__host__
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
return
block_to_ctile_map_
.
CalculateGridSize
(
c_grid_desc_m_n
);
}
UnderlyingBlockToCTileMap
block_to_ctile_map_
;
index_t
block_start_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
A0B0B1DataType
,
// FIXME: don't assume A0/B0/B1 have same datatype
typename
Acc0DataType
,
typename
D0sDataType
,
typename
Acc1DataType
,
typename
C1ShuffleDataType
,
typename
D1sDataType
,
typename
E1DataType
,
typename
A0ElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
CDE0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CDE1ElementwiseOperation
,
InMemoryDataOperationEnum
E1GlobalMemoryDataOperation
,
typename
A0GridDesc_M_K
,
typename
B0GridDesc_N_K
,
typename
D0sGridDesc_M_N
,
typename
B1GridDesc_N_K
,
typename
D1sGridDesc_M_N
,
typename
E1GridDesc_M_N
,
index_t
NumGemm0KPrefetchStage
,
index_t
BlockSize
,
index_t
Gemm0MPerBlock
,
index_t
Gemm0NPerBlock
,
index_t
Gemm0KPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
A0K1Value
,
index_t
B0K1Value
,
index_t
B1K1Value
,
index_t
Gemm0MPerXdl
,
index_t
Gemm0NPerXdl
,
index_t
Gemm0MXdlPerWave
,
index_t
Gemm0NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
typename
A0BlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
A0BlockTransferThreadClusterArrangeOrder
,
typename
A0BlockTransferSrcAccessOrder
,
index_t
A0BlockTransferSrcVectorDim
,
index_t
A0BlockTransferSrcScalarPerVector
,
index_t
A0BlockTransferDstScalarPerVector_AK1
,
bool
A0ThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
A0BlockLdsExtraM
,
typename
B0BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B0BlockTransferThreadClusterArrangeOrder
,
typename
B0BlockTransferSrcAccessOrder
,
index_t
B0BlockTransferSrcVectorDim
,
index_t
B0BlockTransferSrcScalarPerVector
,
index_t
B0BlockTransferDstScalarPerVector_BK1
,
bool
B0ThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
B0BlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_BK1
,
bool
B1ThreadTransferSrcResetCoordinateAfterRun
,
index_t
B1BlockLdsExtraN
,
index_t
C1ShuffleGemm0MXdlPerWavePerShuffle
,
index_t
C1ShuffleGemm0NXdlPerWavePerShuffle
,
typename
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
>
struct
GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
();
static
constexpr
index_t
NumD1Tensor
=
D1sDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
A0K1
=
Number
<
A0K1Value
>
{};
static
constexpr
auto
B0K1
=
Number
<
B0K1Value
>
{};
static
constexpr
auto
A0K0PerBlock
=
Number
<
Gemm0KPerBlock
/
A0K1Value
>
{};
static
constexpr
auto
B0K0PerBlock
=
Number
<
Gemm0KPerBlock
/
B0K1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
Gemm0MPerBlock
/
(
Gemm0MPerXdl
*
Gemm0MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
Gemm0NPerBlock
/
(
Gemm0NPerXdl
*
Gemm0NXdlPerWave
);
// Gemm1
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
B1K0PerBlock
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemm0KPrefetchStage
>
;
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static
constexpr
auto
MakeD0sGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
return
static_cast
<
const
D0DataType
*>
(
nullptr
);
},
Number
<
NumD0Tensor
>
{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static
constexpr
auto
MakeD1sGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
D1DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D1sDataType
>>
;
return
static_cast
<
const
D1DataType
*>
(
nullptr
);
},
Number
<
NumD1Tensor
>
{});
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
WaveSize
/
Gemm0NPerXdl
,
Gemm0NPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
template
<
typename
A0BlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
A0BlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
Gemm0MPerBlock
/
(
Gemm0MXdlPerWave
*
Gemm0MPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm0MXdlPerWave
,
MWaves
,
Gemm0MPerXdl
>
(
A0BlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
Gemm0NPerBlock
/
(
Gemm0NXdlPerWave
*
Gemm0NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm0NXdlPerWave
,
NWaves
,
Gemm0NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
template
<
typename
A0BlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
A0BlockDesc_AK0_M_AK1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm0MXdlPerWave
,
1
,
1
>
(
A0BlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
Gemm1NWaves
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
Gemm0NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm1NXdlPerWave
,
Gemm1NWaves
,
Gemm0NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
__host__
__device__
static
constexpr
auto
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A0 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
A0K0PerBlock
,
Number
<
Gemm0MPerBlock
>
{},
A0K1
),
make_tuple
(
Number
<
Gemm0MPerBlock
+
A0BlockLdsExtraM
>
{}
*
A0K1
,
A0K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B0 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
B0K0PerBlock
,
Number
<
Gemm0NPerBlock
>
{},
B0K1
),
make_tuple
(
Number
<
Gemm0NPerBlock
+
B0BlockLdsExtraN
>
{}
*
B0K1
,
B0K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
B1K0PerBlock
,
Number
<
Gemm1NPerBlock
>
{},
B1K1
),
make_tuple
(
Number
<
Gemm1NPerBlock
+
B1BlockLdsExtraN
>
{}
*
B1K1
,
B1K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
Gemm0MPerBlock
/
(
Gemm0MXdlPerWave
*
Gemm0MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
Gemm0NPerXdl
);
constexpr
auto
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
C1ShuffleGemm0MXdlPerWavePerShuffle
*
MWave
*
Gemm0MPerXdl
>
{},
I1
,
Number
<
C1ShuffleGemm0NXdlPerWavePerShuffle
*
NWave
*
Gemm0NPerXdl
>
{}));
return
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
a0_block_space_size_aligned
+
SharedMemTrait
::
b0_block_space_size_aligned
)
*
sizeof
(
A0B0B1DataType
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
A0B0B1DataType
);
const
index_t
c1_block_bytes_end
=
SharedMemTrait
::
c1_block_space_size
*
sizeof
(
C1ShuffleDataType
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
c1_block_bytes_end
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2E1TileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
A0GridDesc_M_K
&
a0_grid_desc_m_k
,
const
B0GridDesc_N_K
&
b0_grid_desc_n_k
,
const
B1GridDesc_N_K
&
b1_grid_desc_n_k
,
const
E1GridDesc_M_N
&
e1_grid_desc_m_n
,
const
Block2E1TileMap
&
block_2_e1tile_map
)
{
static_assert
((
Gemm0MPerBlock
%
(
Gemm0MPerXdl
*
Gemm0MXdlPerWave
)
==
0
)
&&
(
Gemm0NPerBlock
%
(
Gemm0NXdlPerWave
*
Gemm0NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a0_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b0_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
a0_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
Gemm1N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
if
(
!
(
M
==
e1_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
e1_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
Gemm0MPerBlock
==
0
&&
N
%
Gemm0NPerBlock
==
0
&&
K
%
Gemm0KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
// check gemm0 gridwise gemm pipeline
const
auto
num_gemm0_k_loop
=
K
/
Gemm0KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm0_k_loop
))
{
return
false
;
}
// check gemm1 gridwise gemm pipeline
if
(
!
(
Gemm0NPerBlock
%
Gemm1KPerBlock
==
0
))
{
return
false
;
}
const
auto
num_gemm1_k_inner_loop
=
Gemm0NPerBlock
/
Gemm1KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm1_k_inner_loop
))
{
return
false
;
}
if
(
!
block_2_e1tile_map
.
CheckValidity
(
e1_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
Gemm0KPerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
// A0 desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultA0GridDescriptor_AK0_M_AK1
(
const
A0GridDesc_M_K
&
a0_grid_desc_m_k
)
{
const
auto
M
=
a0_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a0_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
A0K0
=
K
/
A0K1
;
return
transform_tensor_descriptor
(
a0_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A0K0
,
A0K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B0 desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultB0GridDescriptor_BK0_N_BK1
(
const
B0GridDesc_N_K
&
b0_grid_desc_n_k
)
{
const
auto
N
=
b0_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b0_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B0K0
=
K
/
B0K1
;
return
transform_tensor_descriptor
(
b0_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B0K0
,
B0K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// D0 desc for source in blockwise copy
template
<
typename
D0GridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
{
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
Gemm0MPerBlock
,
Gemm0MXdlPerWave
,
Gemm0MWaves
,
Gemm0MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
Gemm0NPerBlock
,
Gemm0NXdlPerWave
,
Gemm0NWaves
,
N3
,
WaveSize
/
Gemm0NPerXdl
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
// B1 desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultB1GridDescriptor_BK0_N_BK1
(
const
B1GridDesc_N_K
&
b1_grid_desc_n_k
)
{
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// C1 desc for destination in blockwise copy
__host__
__device__
static
constexpr
auto
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
E1GridDesc_M_N
&
e1_grid_desc_m_n
)
{
const
auto
M
=
e1_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e1_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
Gemm0MPerBlock
;
const
auto
NBlock
=
N
/
Gemm1NPerBlock
;
const
auto
e1_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
e1_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
Gemm0MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
e1_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// D0s desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
D0sGridDesc_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumD0Tensor
>
{});
}
// Ds desc for source in blockwise copy
template
<
typename
DsGridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
DsGridDescriptor_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumD1Tensor
>
{});
}
// return block_id to C1 matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2E1TileMap
(
const
E1GridDesc_M_N
&
e1_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
Gemm0MPerBlock
,
Gemm1NPerBlock
,
E1GridDesc_M_N
>
(
e1_grid_desc_m_n
);
}
using
E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
E1GridDesc_M_N
{}))
>
;
using
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
D0sGridDesc_M_N
{}))
>
;
using
D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
D1sGridDesc_M_N
{}))
>
;
using
DefaultBlock2E1TileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2E1TileMap
(
E1GridDesc_M_N
{}))
>
;
struct
SharedMemTrait
{
// LDS allocation for A0 and B0: be careful of alignment
static
constexpr
auto
a0_block_desc_ak0_m_ak1
=
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
static
constexpr
auto
b0_block_desc_bk0_n_bk1
=
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
max_lds_align
=
math
::
lcm
(
math
::
lcm
(
A0K1
,
B0K1
),
B1K1
);
static
constexpr
auto
a0_block_space_size_aligned
=
math
::
integer_least_multiple
(
a0_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b0_block_space_size_aligned
=
math
::
integer_least_multiple
(
b0_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b1_block_space_size_aligned
=
math
::
integer_least_multiple
(
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
a0_block_space_offset
=
0
;
static
constexpr
auto
b0_block_space_offset
=
a0_block_space_size_aligned
.
value
;
static
constexpr
auto
b1_block_space_offset
=
0
;
// LDS allocation for C1 shuffle in LDS
static
constexpr
auto
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
static
constexpr
auto
c1_block_space_size
=
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
};
using
D0sGridPointer
=
decltype
(
MakeD0sGridPointer
());
using
D1sGridPointer
=
decltype
(
MakeD1sGridPointer
());
template
<
bool
HasMainKBlockLoop
,
typename
A0GridDesc_AK0_M_AK1
,
typename
B0GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
Block2E1TileMap
>
__device__
static
void
Run
(
const
A0B0B1DataType
*
__restrict__
p_a0_grid
,
const
A0B0B1DataType
*
__restrict__
p_b0_grid
,
D0sGridPointer
p_d0s_grid
,
const
A0B0B1DataType
*
__restrict__
p_b1_grid
,
D1sGridPointer
p_d1s_grid
,
E1DataType
*
__restrict__
p_e1_grid
,
void
*
__restrict__
p_shared
,
const
A0ElementwiseOperation
&
a0_element_op
,
const
B0ElementwiseOperation
&
b0_element_op
,
const
CDE0ElementwiseOperation
&
cde0_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
CDE1ElementwiseOperation
&
cde1_element_op
,
const
A0GridDesc_AK0_M_AK1
&
a0_grid_desc_ak0_m_ak1
,
const
B0GridDesc_BK0_N_BK1
&
b0_grid_desc_bk0_n_bk1
,
const
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
d1s_grid_desc_mblock_mperblock_nblock_nperblock
,
const
E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2E1TileMap
&
block_2_e1tile_map
)
{
const
auto
a0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a0_grid
,
a0_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b0_grid
,
b0_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
const
auto
b1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b1_grid
,
b1_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
e1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_e1_grid
,
e1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
d0s_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0s_grid
[
i
],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
].
GetElementSpaceSize
());
},
Number
<
NumD0Tensor
>
{});
const
auto
d1s_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d1s_grid
[
i
],
d1s_grid_desc_mblock_mperblock_nblock_nperblock
[
i
].
GetElementSpaceSize
());
},
Number
<
NumD1Tensor
>
{});
// divide block work by [M, N]
const
auto
block_work_idx
=
block_2_e1tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_e1tile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
e1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
e1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
Gemm0MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
Gemm1NPerBlock
);
// A0 matrix in LDS memory, dst of blockwise copy
constexpr
auto
a0_block_desc_ak0_m_ak1
=
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B0 matrix in LDS memory, dst of blockwise copy
constexpr
auto
b0_block_desc_bk0_n_bk1
=
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
//
// set up Gemm0
//
// A0 matrix blockwise copy
auto
a0_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
A0ElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
A0K0PerBlock
,
Gemm0MPerBlock
,
A0K1
>
,
A0BlockTransferThreadClusterLengths_AK0_M_AK1
,
A0BlockTransferThreadClusterArrangeOrder
,
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
a0_grid_desc_ak0_m_ak1
),
decltype
(
a0_block_desc_ak0_m_ak1
),
A0BlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
A0BlockTransferSrcVectorDim
,
2
,
A0BlockTransferSrcScalarPerVector
,
A0BlockTransferDstScalarPerVector_AK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemm0KPrefetchStage
>
(
a0_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a0_element_op
,
a0_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// B0 matrix blockwise copy
auto
b0_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
B0ElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
B0K0PerBlock
,
Gemm0NPerBlock
,
B0K1
>
,
B0BlockTransferThreadClusterLengths_BK0_N_BK1
,
B0BlockTransferThreadClusterArrangeOrder
,
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
b0_grid_desc_bk0_n_bk1
),
decltype
(
b0_block_desc_bk0_n_bk1
),
B0BlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
B0BlockTransferSrcVectorDim
,
2
,
B0BlockTransferSrcScalarPerVector
,
B0BlockTransferDstScalarPerVector_BK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemm0KPrefetchStage
>
(
b0_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
// will loop over GemmN dimension
b0_element_op
,
b0_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// Fused Gemm+Gemm pipeline
// for n in N0:
// for k in K0:
// acc[m][n] += A[m][k] * B0[k][n]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
A0K1
,
B0K1
),
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm0
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
A0B0B1DataType
,
Acc0DataType
,
decltype
(
a0_block_desc_ak0_m_ak1
),
decltype
(
b0_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
a0_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b0_block_desc_bk0_n_bk1
)),
Gemm0MPerBlock
,
Gemm0NPerBlock
,
Gemm0KPerBlock
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
Gemm0MXdlPerWave
,
Gemm0NXdlPerWave
,
KPack
,
true
>
{};
// TransposeC
auto
acc0_thread_buf
=
blockwise_gemm0
.
GetCThreadBuffer
();
// LDS allocation for A0 and B0: be careful of alignment
auto
a0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
A0B0B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a0_block_space_offset
,
a0_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
A0B0B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b0_block_space_offset
,
b0_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a0_block_slice_copy_step
=
make_multi_index
(
Gemm0KPerBlock
/
A0K1
,
0
,
0
);
constexpr
auto
b0_block_slice_copy_step
=
make_multi_index
(
Gemm0KPerBlock
/
B0K1
,
0
,
0
);
const
auto
a0_block_reset_copy_step
=
make_multi_index
(
-
a0_grid_desc_ak0_m_ak1
.
GetLength
(
I0
),
0
,
0
);
const
auto
b0_block_reset_copy_step
=
make_multi_index
(
-
b0_grid_desc_bk0_n_bk1
.
GetLength
(
I0
),
Gemm0NPerBlock
,
0
);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const
auto
gridwise_gemm0_pipeline
=
GridwiseGemmPipeline_v1_Selector
<
NumGemm0KPrefetchStage
,
LoopScheduler
::
Default
>
();
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a0_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a0_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
Gemm0KPerBlock
);
//
// set up Gemm1
//
// Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr
auto
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
blockwise_gemm0
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
n0
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
m1
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
n1
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
m2
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
n2
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
n3
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
n4
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
b1_block_slice_copy_step
=
make_multi_index
(
Gemm1KPerBlock
/
B1K1
,
0
,
0
);
// d0 matrix threadwise copy
constexpr
auto
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// MRepeat
I1
,
// NRepeat
I1
,
// MWaveId
I1
,
// NWaveId
I1
,
// MPerXdl
I1
,
// NGroupNum
I1
,
// NInputNum
n4
));
// registerNum
auto
d0s_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
A0B0B1DataType
,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
{};
},
Number
<
NumD0Tensor
>
{});
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
constexpr
auto
acc0_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
Gemm0MXdlPerWave
>
{},
Number
<
Gemm0NXdlPerWave
>
{},
n2
,
n4
));
auto
d0s_threadwise_copy
=
generate_tuple
(
[
&
](
auto
i
)
{
return
ThreadwiseTensorSliceTransfer_v2
<
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
]),
decltype
(
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
Sequence
<
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
n4
,
1
,
false
>
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
));
// register number
},
Number
<
NumD0Tensor
>
{});
// acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
constexpr
auto
acc0_thread_desc_k0_m_k1
=
transform_tensor_descriptor
(
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
n0
,
n1
,
n2
,
n3
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
m0
,
m1
,
m2
)),
make_pass_through_transform
(
n4
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
6
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr
auto
Acc0N3
=
blockwise_gemm0
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLength
(
I6
);
constexpr
auto
A1ThreadSlice_K0_M_K1
=
make_tuple
(
Number
<
Gemm1KPerBlock
/
n4
/
Acc0N3
>
{},
Number
<
m0
*
m1
*
m2
>
{},
Number
<
n4
>
{});
constexpr
auto
A1ThreadSliceK0
=
A1ThreadSlice_K0_M_K1
[
I0
];
constexpr
auto
A1ThreadSliceM
=
A1ThreadSlice_K0_M_K1
[
I1
];
constexpr
auto
A1ThreadSliceK1
=
A1ThreadSlice_K0_M_K1
[
I2
];
constexpr
auto
a1_thread_desc_k0_m_k1
=
make_naive_tensor_descriptor
(
A1ThreadSlice_K0_M_K1
,
make_tuple
(
A1ThreadSliceM
*
A1ThreadSliceK1
,
A1ThreadSliceK1
,
I1
));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A1 matrix blockwise copy
auto
a1_blockwise_copy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
Acc0DataType
,
A0B0B1DataType
,
decltype
(
acc0_thread_desc_k0_m_k1
),
decltype
(
a1_thread_desc_k0_m_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
A1ThreadSliceK0
,
A1ThreadSliceM
,
A1ThreadSliceK1
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
n4
>
{
tensor_operation
::
element_wise
::
PassThrough
{}};
// B1 matrix blockwise copy
auto
b1_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
B0ElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
B1K0PerBlock
,
Gemm1NPerBlock
,
B1K1
>
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
b1_grid_desc_bk0_n_bk1
),
decltype
(
b1_block_desc_bk0_n_bk1
),
B1BlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
B1BlockTransferSrcVectorDim
,
2
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
1
,
1
,
B1ThreadTransferSrcResetCoordinateAfterRun
,
true
,
// DstResetCoord
1
>
(
b1_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b1_element_op
,
b1_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
a1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
A0B0B1DataType
>
(
a1_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
// reuse LDS space for gemm0's b0_block_buf
auto
b1_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
A0B0B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
index_t
Gemm1KPack
=
math
::
max
(
math
::
lcm
(
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm1
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
A0B0B1DataType
,
Acc1DataType
,
decltype
(
a1_thread_desc_k0_m_k1
),
decltype
(
b1_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
a1_thread_desc_k0_m_k1
)),
decltype
(
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
b1_block_desc_bk0_n_bk1
)),
Gemm0MPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
Gemm0MXdlPerWave
,
Gemm1NXdlPerWave
,
Gemm1KPack
,
false
,
// TransposeC
Gemm1KPack
,
// AMmaKStride
Gemm1KPack
*
XdlopsGemm
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
Gemm1KPack
,
false
>
{}
.
K0PerXdlops
>
{
// BMmaKStride
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
auto
c1_thread_buf
=
blockwise_gemm1
.
GetCThreadBuffer
();
const
index_t
num_gemm1_k_block_outer_loop
=
b0_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)
/
Gemm0NPerBlock
;
constexpr
index_t
num_gemm1_k_block_inner_loop
=
Gemm0NPerBlock
/
Gemm1KPerBlock
;
// Initialize C1
c1_thread_buf
.
Clear
();
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
{
// gemm0
gridwise_gemm0_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a0_grid_desc_ak0_m_ak1
,
a0_block_desc_ak0_m_ak1
,
a0_blockwise_copy
,
a0_grid_buf
,
a0_block_buf
,
a0_block_slice_copy_step
,
b0_grid_desc_bk0_n_bk1
,
b0_block_desc_bk0_n_bk1
,
b0_blockwise_copy
,
b0_grid_buf
,
b0_block_buf
,
b0_block_slice_copy_step
,
blockwise_gemm0
,
acc0_thread_buf
,
num_k_block_main_loop
);
// bias+gelu
{
static_for
<
0
,
Gemm0MXdlPerWave
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
Gemm0NXdlPerWave
,
1
>
{}([
&
](
auto
nr
)
{
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
groupid
)
{
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
Run
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
d0s_grid_buf
[
i
],
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d0s_thread_buf
(
i
));
});
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
acc0_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
nr
,
groupid
,
i
));
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
return
d0s_thread_buf
[
iSrc
][
i
];
},
Number
<
NumD0Tensor
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
)
->
auto
&
{
return
acc0_thread_buf
(
Number
<
c_offset
>
{});
},
Number
<
2
>
{});
unpack2
(
cde0_element_op
,
dst_data_refs
,
src_data_refs
);
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
-
n2
.
value
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
1
,
-
Gemm0NXdlPerWave
,
0
,
0
,
0
,
0
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
1
,
-
Gemm0MXdlPerWave
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
});
}
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
b1_blockwise_copy
.
RunRead
(
b1_grid_desc_bk0_n_bk1
,
b1_grid_buf
);
b1_blockwise_copy
.
MoveSrcSliceWindow
(
b1_grid_desc_bk0_n_bk1
,
b1_block_slice_copy_step
);
block_sync_lds
();
// wait for gemm0 LDS read
b1_blockwise_copy
.
RunWrite
(
b1_block_desc_bk0_n_bk1
,
b1_block_buf
);
// main body
if
constexpr
(
num_gemm1_k_block_inner_loop
>
1
)
{
static_for
<
0
,
num_gemm1_k_block_inner_loop
-
1
,
1
>
{}([
&
](
auto
i
)
{
a1_blockwise_copy
.
Run
(
acc0_thread_desc_k0_m_k1
,
make_tuple
(
Number
<
i
*
A1ThreadSliceK0
>
{},
I0
,
I0
),
acc0_thread_buf
,
a1_thread_desc_k0_m_k1
,
make_tuple
(
I0
,
I0
,
I0
),
a1_thread_buf
);
b1_blockwise_copy
.
RunRead
(
b1_grid_desc_bk0_n_bk1
,
b1_grid_buf
);
block_sync_lds
();
blockwise_gemm1
.
Run
(
a1_thread_buf
,
b1_block_buf
,
c1_thread_buf
);
block_sync_lds
();
b1_blockwise_copy
.
MoveSrcSliceWindow
(
b1_grid_desc_bk0_n_bk1
,
b1_block_slice_copy_step
);
b1_blockwise_copy
.
RunWrite
(
b1_block_desc_bk0_n_bk1
,
b1_block_buf
);
});
}
// tail
{
a1_blockwise_copy
.
Run
(
acc0_thread_desc_k0_m_k1
,
make_tuple
(
Number
<
(
num_gemm1_k_block_inner_loop
-
1
)
*
A1ThreadSliceK0
>
{},
I0
,
I0
),
acc0_thread_buf
,
a1_thread_desc_k0_m_k1
,
make_tuple
(
I0
,
I0
,
I0
),
a1_thread_buf
);
block_sync_lds
();
blockwise_gemm1
.
Run
(
a1_thread_buf
,
b1_block_buf
,
c1_thread_buf
);
}
}
// end gemm1
a0_blockwise_copy
.
MoveSrcSliceWindow
(
a0_grid_desc_ak0_m_ak1
,
a0_block_reset_copy_step
);
// rewind K
b0_blockwise_copy
.
MoveSrcSliceWindow
(
b0_grid_desc_bk0_n_bk1
,
b0_block_reset_copy_step
);
// rewind K and step N
block_sync_lds
();
// wait for gemm1 LDS read
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
// shuffle C1 and write out
{
static_assert
(
Gemm0MXdlPerWave
%
C1ShuffleGemm0MXdlPerWavePerShuffle
==
0
&&
Gemm1NXdlPerWave
%
C1ShuffleGemm0NXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
Gemm0MPerBlock
/
(
Gemm0MXdlPerWave
*
Gemm0MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
Gemm0NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm1
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm1
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c1_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
C1ShuffleDataType
*>
(
p_shared
),
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
C1ShuffleGemm0MXdlPerWavePerShuffle
>
{},
// M0 (Gemm0MXdlPerWave) per
// shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = Gemm0MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
C1ShuffleGemm0NXdlPerWavePerShuffle
>
{},
// N0 (Gemm0NXdlPerWave) per
// shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = Gemm0NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM C1 matrix starting index
const
auto
c1_thread_mtx_on_block
=
blockwise_gemm1
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c1_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c1_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c1_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
Acc1DataType
,
C1ShuffleDataType
,
decltype
(
c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
C1ShuffleGemm0MXdlPerWavePerShuffle
,
C1ShuffleGemm0NXdlPerWavePerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
tensor_operation
::
element_wise
::
PassThrough
{}};
// tuple of reference to C/Ds tensor descriptors
const
auto
c1_d1s_desc_refs
=
concat_tuple_of_reference
(
tie
(
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
generate_tie
(
[
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
d1s_grid_desc_mblock_mperblock_nblock_nperblock
[
i
];
},
Number
<
NumD1Tensor
>
{}));
// tuple of reference to C/Ds tensor descriptors
const
auto
c1_d1s_buf_refs
=
concat_tuple_of_reference
(
tie
(
c1_shuffle_block_buf
),
generate_tie
(
[
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
d1s_grid_buf
[
i
];
},
Number
<
NumD1Tensor
>
{}));
// tuple of starting index of C/Ds blockwise copy
const
auto
idx_c1_d1s_block_begin
=
container_concat
(
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
)),
generate_tuple
(
[
&
](
auto
)
{
return
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
);
},
Number
<
NumD1Tensor
>
{}));
// shuffle: blockwise copy C from LDS to global
auto
cde1_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
C1ShuffleDataType
{}),
D1sDataType
{})),
Tuple
<
E1DataType
>
,
decltype
(
c1_d1s_desc_refs
),
decltype
(
tie
(
e1_grid_desc_mblock_mperblock_nblock_nperblock
)),
CDE1ElementwiseOperation
,
Sequence
<
static_cast
<
index_t
>
(
E1GlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
// support arbitray
// type
Sequence
<
1
,
C1ShuffleGemm0MXdlPerWavePerShuffle
*
MWave
*
Gemm0MPerXdl
,
1
,
C1ShuffleGemm0NXdlPerWavePerShuffle
*
NWave
*
Gemm0NPerXdl
>
,
// BlockSliceLengths,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock
,
sequence_merge_t
<
Sequence
<
true
>
,
uniform_sequence_gen_t
<
NumD1Tensor
,
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
{
c1_d1s_desc_refs
,
idx_c1_d1s_block_begin
,
tie
(
e1_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
)),
cde1_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c1_vgpr
=
SpaceFillingCurve
<
Sequence
<
Gemm0MXdlPerWave
,
Gemm1NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
C1ShuffleGemm0MXdlPerWavePerShuffle
,
C1ShuffleGemm0NXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_e1_global
=
SpaceFillingCurve
<
Sequence
<
1
,
Gemm0MPerBlock
,
1
,
Gemm1NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
C1ShuffleGemm0MXdlPerWavePerShuffle
*
MWave
*
Gemm0MPerXdl
,
1
,
C1ShuffleGemm0NXdlPerWavePerShuffle
*
NWave
*
Gemm0NPerXdl
>>
{};
constexpr
index_t
num_access
=
sfc_c1_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_e1_global
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c1_thread_copy_vgpr_to_lds
.
Run
(
c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c1_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c1_thread_buf
,
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c1_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
cde1_shuffle_block_copy_lds_to_global
.
Run
(
c1_d1s_desc_refs
,
c1_d1s_buf_refs
,
tie
(
e1_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e1_grid_buf
));
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
e1_global_step
=
sfc_e1_global
.
GetForwardStep
(
access_id
);
// move on D1s
static_for
<
0
,
NumD1Tensor
,
1
>
{}([
&
](
auto
i
)
{
cde1_shuffle_block_copy_lds_to_global
.
MoveSrcSliceWindow
(
c1_d1s_desc_refs
,
i
+
I1
,
e1_global_step
);
});
// move on C
cde1_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
tie
(
e1_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
e1_global_step
);
}
});
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
5aa3c344
...
...
@@ -76,7 +76,8 @@ template <typename FloatAB,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
bool
PadN
>
bool
PadN
,
bool
MaskOutUpperTriangle
>
struct
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
...
...
@@ -97,6 +98,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
// Gemm1
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
...
...
@@ -361,7 +366,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
,
typename
C0MatrixMask
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
...
...
@@ -377,22 +382,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
)
{
const
auto
a_grid_buf
=
conditional_expr
<
PadN
>
(
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
NumericLimits
<
FloatAB
>::
QuietNaN
()),
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
()));
const
auto
b_grid_buf
=
conditional_expr
<
PadN
>
(
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
NumericLimits
<
FloatAB
>::
QuietNaN
()),
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
()));
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
const
auto
b1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b1_grid
,
b1_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -749,10 +745,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
running_max_new
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
// decoder lower triangular mask
const
auto
thread_cluster_idx
=
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
MPerRepeat
=
MPerBlock
/
MXdlPerWave
;
const
index_t
NPerRepeat
=
NPerBlock
/
NXdlPerWave
;
const
index_t
mstart
=
m_block_data_idx_on_grid
+
thread_m_cluster_id
;
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
{
if
constexpr
(
MaskOutUpperTriangle
)
{
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
(
c0_matrix_mask
.
IsUpperTriangle
(
m_block_data_idx_on_grid
,
gemm0_n_block_idx
)
&&
c0_matrix_mask
.
IsUpperTriangle
(
m_block_data_idx_on_grid
+
MPerBlock
-
1
,
gemm0_n_block_idx
))
{
continue
;
}
}
// gemm0
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
...
...
@@ -770,16 +786,63 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf
,
num_k_block_main_loop
);
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
#else
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
ElementOpPredicatedResetNaNToMinusInf
<
PadN
>
{}.
Run
(
acc_thread_buf
(
i
),
acc_element_op
,
acc_thread_buf
[
i
]);
});
#endif
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
{
const
index_t
nstart
=
gemm1_k_block_outer_index
*
NPerBlock
;
static_for
<
0
,
m0
,
1
>
{}([
&
](
auto
m0_i
)
{
const
index_t
m_global
=
mstart
+
m0_i
*
MPerRepeat
;
const
index_t
acc_idx_m0
=
m0_i
*
n0
*
n2
*
n4
;
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
n0_i
)
{
// constexpr auto nrepeat_i = n0_i * NPerRepeat;
// const index_t nstartxdl = nstart + nrepeat_i;
const
index_t
nstartxdl
=
nstart
+
n0_i
*
NPerRepeat
;
const
index_t
acc_idx_n0
=
acc_idx_m0
+
n0_i
*
n2
*
n4
;
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
n2_i
)
{
const
index_t
nstartgroup
=
nstartxdl
+
thread_n_cluster_id
*
n4
+
n2_i
*
AccN3
*
n4
;
const
index_t
acc_idx_n2
=
acc_idx_n0
+
n2_i
*
n4
;
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
n4_i
)
{
const
index_t
n_global
=
nstartgroup
+
n4_i
;
const
auto
acc_offset
=
Number
<
acc_idx_n2
+
n4_i
>
{};
if
constexpr
(
MaskOutUpperTriangle
)
{
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
}
}
else
{
// ignore m_global;
if
(
c0_matrix_mask
.
IsNOutOfBound
(
n_global
))
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
}
}
});
});
});
});
}
else
{
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
...
...
@@ -881,9 +944,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc
c_new
=
(
running_sum
[
iM
]
*
math
::
exp
(
running_max
[
iM
]
-
running_max_new
[
iM
])
*
c
+
math
::
exp
(
max
[
iM
]
-
running_max_new
[
iM
])
*
acc1
)
/
running_sum_new
[
iM
];
// O_new
running_sum_new
[
iM
];
// Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf
(
I
)
=
c_new
;
c_thread_buf
(
I
)
=
c_new
;
// O_new
});
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp
View file @
5aa3c344
...
...
@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
auto
in_global_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
static_assert
(
in_grid_1d_desc_tuple
[
I
].
GetNumOfDimension
()
==
1
);
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global_tuple
[
I
],
in_grid_1d_desc_tuple
[
I
].
GetElementSpaceSize
());
},
...
...
@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
auto
out_global_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
static_assert
(
out_grid_1d_desc_tuple
[
I
].
GetNumOfDimension
()
==
1
);
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global_tuple
[
I
],
out_grid_1d_desc_tuple
[
I
].
GetElementSpaceSize
());
},
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
5aa3c344
...
...
@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
typename
AGridDesc_M_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// A desc for source in blockwise copy
template
<
typename
AGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeDefaultAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
...
...
@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeDefaultBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
...
...
@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// E desc for destination in blockwise copy
template
<
typename
EGridDesc
riptor
_M_N
>
__host__
__device__
static
constexpr
auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDescriptor_M_N
&
e_grid_desc_m_n
)
template
<
typename
EGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
Make
EGridDescriptor_
MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDesc_
M_N
&
e_grid_desc_m_n
)
{
const
auto
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
...
...
@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// Ds desc for source in blockwise copy
template
<
typename
DsGridDesc
riptor
_M_N
>
template
<
typename
DsGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
DsGridDescriptor_M_N
&
ds_grid_desc_m_n
)
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
...
...
@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// return block_id to E matrix tile idx (m0, n0) mapping
template
<
typename
EGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2ETileMap
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
...
...
@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2ETileMap
>
template
<
typename
AGridDesc_M_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
...
...
@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
using
DefaultAGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
DefaultBGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
DefaultBlock2ETileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
template
<
bool
HasMainKBlockLoop
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
...
...
@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
DsGridDesc
riptor
_MBlock_MPerBlock_NBlock_NPerBlock
&
const
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc
riptor
_MBlock_MPerBlock_NBlock_NPerBlock
&
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
&
block_2_etile_map
)
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp
View file @
5aa3c344
...
...
@@ -22,7 +22,6 @@ template <typename XDataType,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_K
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
...
...
@@ -30,7 +29,9 @@ template <typename XDataType,
index_t
KThreadSliceSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
...
...
@@ -78,13 +79,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
const
GridDesc_K
&
gamma_grid_desc_k
,
const
GridDesc_K
&
beta_grid_desc_k
,
const
GridDesc_
M_
K
&
gamma_grid_desc_
m_
k
,
const
GridDesc_
M_
K
&
beta_grid_desc_
m_
k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
...
...
@@ -111,11 +113,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
KThreadSliceSize
,
true
>
gamma_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
KThreadSliceSize
,
true
>&
beta_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
gamma_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>&
beta_thread_buf
=
gamma_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
...
...
@@ -127,7 +132,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_square_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>&
var_
value
_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>&
var_
thread
_buf
=
mean_square_thread_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
...
@@ -145,11 +150,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_K
=
Sequence
<
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
...
...
@@ -169,27 +171,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
AccDataType
,
GridDesc_K
,
decltype
(
thread_buffer_desc_k
),
ThreadBufferLengths_K
,
Sequence
<
0
>
,
0
,
GridDesc_
M_
K
,
decltype
(
thread_buffer_desc_
m_
k
),
ThreadBufferLengths_
M_
K
,
ThreadBufferDimAccessOrder
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
1
,
true
>
(
gamma_grid_desc_k
,
make_multi_index
(
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
AccDataType
,
GridDesc_K
,
decltype
(
thread_buffer_desc_k
),
ThreadBufferLengths_K
,
Sequence
<
0
>
,
0
,
BetaSrcVectorSize
,
1
,
true
>
(
beta_grid_desc_k
,
make_multi_index
(
thread_k_cluster_id
*
KThreadSliceSize
));
gamma_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
AccDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
1
,
true
>
(
beta_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
...
...
@@ -212,9 +221,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr
auto
thread_copy_fwd_step_k
=
make_multi_index
(
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_k
=
make_multi_index
(
SweepOnce
?
0
:
-
K_BlockTileSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
...
...
@@ -224,13 +230,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_global
,
gamma_grid_desc_k
.
GetElementSpaceSize
());
p_gamma_global
,
gamma_grid_desc_
m_
k
.
GetElementSpaceSize
());
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_beta_global
,
beta_grid_desc_k
.
GetElementSpaceSize
());
p_beta_global
,
beta_grid_desc_
m_
k
.
GetElementSpaceSize
());
// E(x), E[x^2], var(x)
int
reduce_length
=
x_grid_desc_m_k
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I1
];
// FIXME: Should not hack the transform from deviceOP
int
reduce_length
=
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
];
index_t
reducedTiles
=
0
;
do
...
...
@@ -271,17 +278,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
mean_square_thread_buf
(
I
)
=
mean_square_thread_buf
(
I
)
/
reduce_length
;
// var(x) = E[x^2] - E[x]^2
var_
value
_buf
(
I
)
=
var_
thread
_buf
(
I
)
=
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
});
// y = (x - E[x]) / sqrt(var[x] + epsilon)
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_m_k
;
auto
thread_copy_tail_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_k
;
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_k
,
thread_copy_tail_k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_k
,
thread_copy_tail_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_
m_
k
,
thread_copy_tail_
m_
k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_
m_
k
,
thread_copy_tail_
m_
k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
reducedTiles
=
0
;
...
...
@@ -296,10 +302,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
x_thread_buf
);
}
threadwise_gamma_load
.
Run
(
gamma_grid_desc_k
,
threadwise_gamma_load
.
Run
(
gamma_grid_desc_
m_
k
,
gamma_global_val_buf
,
thread_buffer_desc_k
,
make_tuple
(
I0
),
thread_buffer_desc_
m_
k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
...
...
@@ -307,23 +313,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
constexpr
auto
offset_k
=
thread_buffer_desc_k
.
CalculateOffset
(
make_tuple
(
iK
));
// normalize
y_thread_buf
(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
/
sqrt
(
var_
value
_buf
(
iM
)
+
epsilon
);
sqrt
(
var_
thread
_buf
(
iM
)
+
epsilon
);
// gamma
y_thread_buf
(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
Number
<
offset_m_k
>
{})
*
gamma_thread_buf
(
Number
<
offset_k
>
{});
y_thread_buf
(
Number
<
offset_m_k
>
{})
*
gamma_thread_buf
(
Number
<
offset_
m_
k
>
{});
});
});
threadwise_beta_load
.
Run
(
beta_grid_desc_k
,
threadwise_beta_load
.
Run
(
beta_grid_desc_
m_
k
,
beta_global_val_buf
,
thread_buffer_desc_k
,
make_tuple
(
I0
),
thread_buffer_desc_
m_
k
,
make_tuple
(
I0
,
I0
),
beta_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
...
...
@@ -331,11 +335,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
constexpr
auto
offset_k
=
thread_buffer_desc_k
.
CalculateOffset
(
make_tuple
(
iK
));
// beta
y_thread_buf
(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
Number
<
offset_m_k
>
{})
+
beta_thread_buf
(
Number
<
offset_k
>
{});
y_thread_buf
(
Number
<
offset_m_k
>
{})
+
beta_thread_buf
(
Number
<
offset_
m_
k
>
{});
});
});
...
...
@@ -346,8 +348,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_k
,
thread_copy_bwd_step_k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_k
,
thread_copy_bwd_step_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_
m_
k
,
thread_copy_bwd_step_
m_
k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_
m_
k
,
thread_copy_bwd_step_
m_
k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
++
reducedTiles
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
View file @
5aa3c344
...
...
@@ -19,7 +19,6 @@ template <typename XDataType,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_K
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
...
...
@@ -27,7 +26,9 @@ template <typename XDataType,
index_t
KThreadSliceSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
...
...
@@ -70,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
...
...
@@ -77,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
__device__
static
int
GetKPerThread
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
int
thread_k_cluster_id
)
{
int
kPerBlock
=
x_grid_desc_m_k
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I1
];
// FIXME: Should not hack the transform from deviceOP
int
kPerBlock
=
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
];
int
kPerThread
=
kPerBlock
<
K_BlockTileSize
?
0
:
KThreadSliceSize
*
(
kPerBlock
/
K_BlockTileSize
);
int
kPerBlockTail
=
kPerBlock
-
kPerThread
*
KThreadClusterSize
;
...
...
@@ -94,8 +97,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}
__device__
static
void
Run
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
const
GridDesc_K
&
gamma_grid_desc_k
,
const
GridDesc_K
&
beta_grid_desc_k
,
const
GridDesc_
M_
K
&
gamma_grid_desc_
m_
k
,
const
GridDesc_
M_
K
&
beta_grid_desc_
m_
k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
...
...
@@ -116,11 +119,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
KThreadSliceSize
,
true
>
gamma_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
KThreadSliceSize
,
true
>&
beta_thread_buf
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
gamma_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>&
beta_thread_buf
=
gamma_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
...
...
@@ -137,11 +143,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_K
=
Sequence
<
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
...
...
@@ -161,27 +164,34 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
AccDataType
,
GridDesc_K
,
decltype
(
thread_buffer_desc_k
),
ThreadBufferLengths_K
,
Sequence
<
0
>
,
0
,
GridDesc_
M_
K
,
decltype
(
thread_buffer_desc_
m_
k
),
ThreadBufferLengths_
M_
K
,
ThreadBufferDimAccessOrder
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
1
,
true
>
(
gamma_grid_desc_k
,
make_multi_index
(
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
AccDataType
,
GridDesc_K
,
decltype
(
thread_buffer_desc_k
),
ThreadBufferLengths_K
,
Sequence
<
0
>
,
0
,
BetaSrcVectorSize
,
1
,
true
>
(
beta_grid_desc_k
,
make_multi_index
(
thread_k_cluster_id
*
KThreadSliceSize
));
gamma_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
AccDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
1
,
true
>
(
beta_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
...
...
@@ -204,9 +214,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr
auto
thread_copy_fwd_step_k
=
make_multi_index
(
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_k
=
make_multi_index
(
SweepOnce
?
0
:
-
K_BlockTileSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
...
...
@@ -216,10 +223,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_global
,
gamma_grid_desc_k
.
GetElementSpaceSize
());
p_gamma_global
,
gamma_grid_desc_
m_
k
.
GetElementSpaceSize
());
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_beta_global
,
beta_grid_desc_k
.
GetElementSpaceSize
());
p_beta_global
,
beta_grid_desc_
m_
k
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
GetKPerThread
(
x_grid_desc_m_k
,
thread_k_cluster_id
);
...
...
@@ -250,11 +257,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
});
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_m_k
;
auto
thread_copy_tail_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_k
;
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_k
,
thread_copy_tail_k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_k
,
thread_copy_tail_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_
m_
k
,
thread_copy_tail_
m_
k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_
m_
k
,
thread_copy_tail_
m_
k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
...
...
@@ -268,10 +274,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
x_thread_buf
);
}
threadwise_gamma_load
.
Run
(
gamma_grid_desc_k
,
threadwise_gamma_load
.
Run
(
gamma_grid_desc_
m_
k
,
gamma_global_val_buf
,
thread_buffer_desc_k
,
make_tuple
(
I0
),
thread_buffer_desc_
m_
k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
...
...
@@ -279,8 +285,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
constexpr
auto
offset_k
=
thread_buffer_desc_k
.
CalculateOffset
(
make_tuple
(
iK
));
// normalize
y_thread_buf
(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
/
...
...
@@ -288,14 +292,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
// gamma
y_thread_buf
(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
Number
<
offset_m_k
>
{})
*
gamma_thread_buf
(
Number
<
offset_k
>
{});
y_thread_buf
(
Number
<
offset_m_k
>
{})
*
gamma_thread_buf
(
Number
<
offset_
m_
k
>
{});
});
});
threadwise_beta_load
.
Run
(
beta_grid_desc_k
,
threadwise_beta_load
.
Run
(
beta_grid_desc_
m_
k
,
beta_global_val_buf
,
thread_buffer_desc_k
,
make_tuple
(
I0
),
thread_buffer_desc_
m_
k
,
make_tuple
(
I0
,
I0
),
beta_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
...
...
@@ -303,11 +307,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
constexpr
auto
offset_k
=
thread_buffer_desc_k
.
CalculateOffset
(
make_tuple
(
iK
));
// beta
y_thread_buf
(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
Number
<
offset_m_k
>
{})
+
beta_thread_buf
(
Number
<
offset_k
>
{});
y_thread_buf
(
Number
<
offset_m_k
>
{})
+
beta_thread_buf
(
Number
<
offset_
m_
k
>
{});
});
});
...
...
@@ -318,8 +320,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_k
,
thread_copy_bwd_step_k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_k
,
thread_copy_bwd_step_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_
m_
k
,
thread_copy_bwd_step_
m_
k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_
m_
k
,
thread_copy_bwd_step_
m_
k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <numeric>
#include <iterator>
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwisePermute
,
typename
InGridDesc
,
typename
OutGridDesc
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
typename
Block2TileMap
>
__global__
void
kernel_nd_permute
(
const
InGridDesc
in_grid_desc
,
const
OutGridDesc
out_grid_desc
,
const
InDataType
*
p_in_global
,
OutDataType
*
p_out_global
,
const
ElementwiseOperation
elementwise_op
,
const
Block2TileMap
block_2_tile_map
)
{
__shared__
char
p_shared
[
GridwisePermute
::
GetSharedMemoryNumberOfByte
()];
GridwisePermute
::
Run
(
in_grid_desc
,
out_grid_desc
,
p_in_global
,
p_out_global
,
p_shared
,
elementwise_op
,
block_2_tile_map
);
}
template
<
typename
InGridDesc
,
typename
OutGridDesc
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
index_t
BlockSize
,
index_t
NPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
InBlockLdsExtraW
,
typename
InBlockTransferThreadClusterLengths
,
typename
InBlockTransferThreadClusterArrangeOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
struct
GridwisePermute
{
static_assert
(
InGridDesc
::
GetNumOfDimension
()
==
OutGridDesc
::
GetNumOfDimension
());
static_assert
(
3
<=
InGridDesc
::
GetNumOfDimension
());
static_assert
((
InGridDesc
::
GetNumOfDimension
()
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
InGridDesc
::
GetNumOfDimension
());
static_assert
((
OutGridDesc
::
GetNumOfDimension
()
-
2
)
<=
DstVectorDim
&&
DstVectorDim
<
OutGridDesc
::
GetNumOfDimension
());
static_assert
(
SrcVectorDim
!=
DstVectorDim
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
struct
Block2TileMap
{
static
constexpr
index_t
NumDim
=
InGridDesc
::
GetNumOfDimension
();
static_assert
(
3
<=
NumDim
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
Block2TileMap
()
=
delete
;
Block2TileMap
(
const
Block2TileMap
&
)
=
default
;
Block2TileMap
(
Block2TileMap
&&
)
=
delete
;
~
Block2TileMap
()
=
default
;
Block2TileMap
&
operator
=
(
const
Block2TileMap
&
)
=
delete
;
Block2TileMap
&
operator
=
(
Block2TileMap
&&
)
=
delete
;
explicit
Block2TileMap
(
const
InGridDesc
&
desc
)
:
desc_
(
desc
)
{}
__host__
constexpr
index_t
CalculateGridSize
(
const
InGridDesc
&
desc
)
const
{
const
auto
N0
=
math
::
integer_divide_ceil
(
desc
.
GetLength
(
Number
<
NumDim
-
3
>
{}),
NPerBlock
);
const
auto
H0
=
math
::
integer_divide_ceil
(
desc
.
GetLength
(
Number
<
NumDim
-
2
>
{}),
HPerBlock
);
const
auto
W0
=
math
::
integer_divide_ceil
(
desc
.
GetLength
(
Number
<
NumDim
-
1
>
{}),
WPerBlock
);
const
index_t
grid_size
=
N0
*
H0
*
W0
;
return
grid_size
;
}
template
<
typename
TopIdx
>
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
const
{
static_assert
(
TopIdx
::
Size
()
==
1
);
auto
block_1d_id
=
idx_top
[
I0
];
const
auto
N0
=
math
::
integer_divide_ceil
(
desc_
.
GetLength
(
Number
<
NumDim
-
3
>
{}),
NPerBlock
);
const
auto
H0
=
math
::
integer_divide_ceil
(
desc_
.
GetLength
(
Number
<
NumDim
-
2
>
{}),
HPerBlock
);
const
auto
W0
=
math
::
integer_divide_ceil
(
desc_
.
GetLength
(
Number
<
NumDim
-
1
>
{}),
WPerBlock
);
block_1d_id
=
block_1d_id
%
(
N0
*
H0
*
W0
);
index_t
idx_N0
=
block_1d_id
/
(
H0
*
W0
);
index_t
idx_H0
=
(
block_1d_id
%
(
H0
*
W0
))
/
W0
;
index_t
idx_W0
=
block_1d_id
%
W0
;
return
make_tuple
(
idx_N0
,
idx_H0
,
idx_W0
);
}
private:
const
InGridDesc
desc_
;
};
using
DefaultBlock2TileMap
=
Block2TileMap
;
// use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
__host__
__device__
static
constexpr
auto
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock
()
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NPerBlock
>
{},
Number
<
HPerBlock
>
{},
Number
<
WPerBlock
>
{}),
make_tuple
(
Number
<
HPerBlock
*
(
WPerBlock
+
InBlockLdsExtraW
)
>
{},
Number
<
WPerBlock
+
InBlockLdsExtraW
>
{},
I1
));
}
// for N-dimension descriptor, reserve its last 2 dimensions, then merge its leading dimensions
// into single one. finally, form a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
// [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
template
<
typename
GridDesc
>
__host__
__device__
static
constexpr
auto
GetMergedDesc
(
const
GridDesc
&
desc
)
{
constexpr
index_t
NumDim
=
GridDesc
::
GetNumOfDimension
();
static_assert
(
3
<=
NumDim
);
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
generate_tuple
(
[
&
](
auto
I
)
{
return
desc
.
GetLength
(
I
);
},
Number
<
NumDim
-
2
>
{})),
make_pass_through_transform
(
desc
.
GetLength
(
Number
<
NumDim
-
2
>
{})),
make_pass_through_transform
(
desc
.
GetLength
(
Number
<
NumDim
-
1
>
{}))),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
NumDim
-
2
>
{}),
Sequence
<
NumDim
-
2
>
{},
Sequence
<
NumDim
-
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
merged_desc
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
in_block_desc_nperblock_hperblock_wperblock
=
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock
();
return
in_block_desc_nperblock_hperblock_wperblock
.
GetElementSpaceSize
()
*
sizeof
(
InDataType
);
}
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2TileMap
(
const
InGridDesc
&
desc
)
{
return
DefaultBlock2TileMap
{
desc
};
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
InGridDesc
&
in_grid_desc
,
const
OutGridDesc
&
out_grid_desc
)
{
constexpr
index_t
NumDim
=
InGridDesc
::
GetNumOfDimension
();
// check if we only swap last 2 dimensions
bool
valid
=
true
;
static_for
<
0
,
NumDim
-
2
,
1
>
{}([
&
](
auto
I
)
{
if
(
valid
&&
in_grid_desc
.
GetLength
(
I
)
!=
out_grid_desc
.
GetLength
(
I
))
{
valid
=
false
;
}
});
return
valid
&&
(
in_grid_desc
.
GetLength
(
Number
<
NumDim
-
1
>
{})
==
out_grid_desc
.
GetLength
(
Number
<
NumDim
-
2
>
{}))
&&
(
in_grid_desc
.
GetLength
(
Number
<
NumDim
-
2
>
{})
==
out_grid_desc
.
GetLength
(
Number
<
NumDim
-
1
>
{}));
}
template
<
typename
Block2TileMap
>
__device__
static
void
Run
(
const
InGridDesc
in_grid_desc
,
const
OutGridDesc
out_grid_desc
,
const
InDataType
*
p_in_global
,
OutDataType
*
p_out_global
,
void
*
__restrict__
p_shared
,
const
ElementwiseOperation
elementwise_op
,
const
Block2TileMap
&
block_2_tile_map
)
{
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc
.
GetElementSpaceSize
());
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc
.
GetElementSpaceSize
());
// each workgroup handles an [NPerBlock, HPerBlock, WPerBLock] slice-transpose problem
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
NPerBlock
);
const
index_t
h_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
HPerBlock
);
const
index_t
w_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I2
]
*
WPerBlock
);
// create [NPerBlock, HPerBlock, WPerBLock] shaped LDS buffer
constexpr
auto
in_block_desc_nperblock_hperblock_wperblock
=
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock
();
auto
in_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
InDataType
*>
(
p_shared
),
in_block_desc_nperblock_hperblock_wperblock
.
GetElementSpaceSize
());
using
BlockSliceLengths
=
Sequence
<
NPerBlock
,
HPerBlock
,
WPerBlock
>
;
using
InBlockTransferAccessOrder
=
Sequence
<
0
,
1
,
2
>
;
constexpr
index_t
SrcVectorDimAfterMerge
=
SrcVectorDim
-
(
InGridDesc
::
GetNumOfDimension
()
-
3
);
constexpr
index_t
DstVectorDimAfterMerge
=
SrcVectorDimAfterMerge
;
using
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
// ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
const
auto
in_grid_desc_n_h_w
=
GetMergedDesc
(
in_grid_desc
);
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
auto
in_global_load
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ElementwiseOperation
,
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
BlockSliceLengths
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
InDataType
,
InDataType
,
decltype
(
in_grid_desc_n_h_w
),
decltype
(
in_block_desc_nperblock_hperblock_wperblock
),
InBlockTransferAccessOrder
,
InBlockTransferAccessOrder
,
SrcVectorDimAfterMerge
,
2
,
SrcScalarPerVector
,
1
,
1
,
1
,
true
,
true
>
(
in_grid_desc_n_h_w
,
make_multi_index
(
n_block_data_idx_on_grid
,
h_block_data_idx_on_grid
,
w_block_data_idx_on_grid
),
PassThrough
{},
in_block_desc_nperblock_hperblock_wperblock
,
make_multi_index
(
0
,
0
,
0
),
PassThrough
{});
// merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
// ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
const
auto
out_grid_desc_n_w_h
=
GetMergedDesc
(
out_grid_desc
);
// create transposed view of output tensor
const
auto
out_grid_desc_n_h_w
=
transform_tensor_descriptor
(
out_grid_desc_n_w_h
,
make_tuple
(
make_pass_through_transform
(
out_grid_desc_n_w_h
.
GetLength
(
I0
)),
make_pass_through_transform
(
out_grid_desc_n_w_h
.
GetLength
(
I1
)),
make_pass_through_transform
(
out_grid_desc_n_w_h
.
GetLength
(
I2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
1
>
{}));
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
auto
out_global_store
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ElementwiseOperation
,
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
BlockSliceLengths
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
InDataType
,
OutDataType
,
decltype
(
in_block_desc_nperblock_hperblock_wperblock
),
decltype
(
out_grid_desc_n_h_w
),
InBlockTransferAccessOrder
,
InBlockTransferAccessOrder
,
2
,
DstVectorDimAfterMerge
,
1
,
DstScalarPerVector
,
1
,
1
,
true
,
true
>
(
in_block_desc_nperblock_hperblock_wperblock
,
make_multi_index
(
0
,
0
,
0
),
PassThrough
{},
out_grid_desc_n_h_w
,
make_multi_index
(
n_block_data_idx_on_grid
,
h_block_data_idx_on_grid
,
w_block_data_idx_on_grid
),
elementwise_op
);
in_global_load
.
Run
(
in_grid_desc_n_h_w
,
in_global_buf
,
in_block_desc_nperblock_hperblock_wperblock
,
in_block_buf
,
I0
);
out_global_store
.
Run
(
in_block_desc_nperblock_hperblock_wperblock
,
in_block_buf
,
out_grid_desc_n_h_w
,
out_global_buf
,
I0
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
5aa3c344
...
...
@@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace
ck
{
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
namespace
tensor_operation
{
template
<
index_t
NDimSpatial
,
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
ConvBwdDataSpecialization
,
index_t
AK1
,
index_t
BK1
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
bool
DoPadGemmM
,
bool
DoPadGemmN
>
struct
TransformConvBwdDataToGemm_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>,
bool
>::
type
=
false
>
static
auto
MakeADescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* in_g_n_c_wis_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_right_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
tildes
)
{
index_t
i_ytilde
=
tildes
[
0
];
index_t
i_xtilde
=
tildes
[
1
];
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
K
=
wei_g_k_c_xs_lengths
[
1
];
const
index_t
Hi
=
in_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
in_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
4
];
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
4
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
AK0
=
K
/
AK1
;
// assume packed
const
auto
out_n_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
));
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
else
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// A: output tensor
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{}));
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
AK0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
AK1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
}
template
<
typename
BLayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>,
bool
>::
type
=
false
>
static
auto
MakeBDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* in_g_n_c_wis_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_left_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_right_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
tildes
)
{
index_t
i_ytilde
=
tildes
[
0
];
index_t
i_xtilde
=
tildes
[
1
];
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
K
=
wei_g_k_c_xs_lengths
[
1
];
const
index_t
C
=
wei_g_k_c_xs_lengths
[
2
];
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
4
];
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
4
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
BK0
=
K
/
BK1
;
// assume packed
const
auto
wei_k_y_x_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
,
X
,
C
));
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
// B: weight tensor
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
C
),
make_tuple
(
I0
,
I1
));
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
BK0
,
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
}
else
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// B weight tensor
const
auto
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilde
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTilde
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
BK0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
BK1
)),
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
}
}
template
<
typename
CLayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
),
bool
>::
type
=
false
>
static
auto
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
tildes
)
{
index_t
i_ytilde
=
tildes
[
0
];
index_t
i_xtilde
=
tildes
[
1
];
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
wei_g_k_c_xs_lengths
[
2
];
const
index_t
Hi
=
in_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
in_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
4
];
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
4
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
// assume strided
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
in_g_n_c_wis_strides
[
1
],
in_g_n_c_wis_strides
[
3
],
in_g_n_c_wis_strides
[
4
],
in_g_n_c_wis_strides
[
2
]));
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
// C: input tensor
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
I1
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
in_gemmmraw_gemmnraw_grid_desc
,
make_tuple
(
GemmMPerBlock
,
GemmNPerBlock
),
Sequence
<
DoPadGemmM
,
DoPadGemmN
>
{});
return
in_gemmm_gemmn_grid_desc
;
}
else
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// C: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ytilde
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
in_n_htildeslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
in_gemmmraw_gemmnraw_grid_desc
,
make_tuple
(
GemmMPerBlock
,
GemmNPerBlock
),
Sequence
<
DoPadGemmM
,
DoPadGemmN
>
{});
return
in_gemmm_gemmn_grid_desc
;
}
}
// for input bias
template
<
typename
CLayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_C
>
),
bool
>::
type
=
false
>
static
auto
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* in_g_n_c_wis_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_right_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* tildes */
)
{
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
wei_g_k_c_xs_lengths
[
2
];
const
index_t
Hi
=
in_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
in_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
4
];
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
4
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
const
auto
in_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
C
),
make_tuple
(
I0
,
I1
));
return
in_gemmm_gemmn_grid_desc
;
}
else
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// bias tensor
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
*
HTildeSlice
*
WTildeSlice
,
C
),
make_tuple
(
I0
,
I1
));
const
auto
in_gemmm_gemmn_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
in_gemmmraw_gemmnraw_grid_desc
,
make_tuple
(
GemmMPerBlock
,
GemmNPerBlock
),
Sequence
<
DoPadGemmM
,
DoPadGemmN
>
{});
return
in_gemmm_gemmn_grid_desc
;
}
}
};
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @
5aa3c344
...
...
@@ -16,6 +16,7 @@ namespace tensor_operation {
template
<
index_t
NDimSpatial
,
device
::
ConvolutionForwardSpecialization
ConvForwardSpecialization
>
struct
TransformConvFwdToGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
template
<
typename
ALayout
,
...
...
@@ -864,6 +865,29 @@ struct TransformConvFwdToGemm
return
out_gemmm_gemmn_desc
;
}
// for output bias
template
<
typename
CLayout
,
typename
std
::
enable_if
<
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GK
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_K
>
,
bool
>::
type
=
false
>
static
auto
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
)
{
const
index_t
N
=
c_g_n_k_wos_lengths
[
1
];
const
index_t
K
=
c_g_n_k_wos_lengths
[
2
];
const
index_t
NHoWo
=
N
*
std
::
accumulate
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
c_g_n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
auto
out_gemmm_gemmn_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
K
),
make_tuple
(
I0
,
I1
));
return
out_gemmm_gemmn_desc
;
}
};
}
// namespace tensor_operation
...
...
include/ck/utility/ignore.hpp
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_IGNORE_HPP
#define CK_IGNORE_HPP
#pragma once
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
...
...
@@ -21,4 +20,3 @@ struct ignore_t
inline
constexpr
detail
::
ignore_t
ignore
;
}
// namespace ck
#endif
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment