gaoqiong / composable_kernel

Commit b89a88b5
Authored Sep 19, 2022 by Adam Osewski

Merge branch 'develop' into wavelet_model

Parents: 41d5fca7, 43c898f6
Changes: 261
Showing 20 changed files with 3436 additions and 691 deletions (+3436 / -691)
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp  (+18 / -0)
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp  (+272 / -0)
include/ck/tensor_operation/gpu/device/matrix_padder.hpp  (+218 / -20)
include/ck/tensor_operation/gpu/device/tensor_layout.hpp  (+49 / -1)
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp  (+55 / -0)
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  (+59 / -16)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (+51 / -2)
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp  (+321 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp  (+264 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp  (+0 / -254)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp  (+59 / -46)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp  (+1268 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+114 / -34)
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp  (+0 / -155)
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp  (+191 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp  (+63 / -22)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  (+4 / -9)
include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp  (+86 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp  (+344 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp  (+0 / -132)
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp
@@ -9,6 +9,7 @@ namespace device {
enum struct GemmSpecialization
{
    // Gemm
    Default,
    MPadding,
    NPadding,
@@ -17,6 +18,15 @@ enum struct GemmSpecialization
    MKPadding,
    NKPadding,
    MNKPadding,
    // Gemm + Gemm
    OPadding,
    MOPadding,
    NOPadding,
    KOPadding,
    MNOPadding,
    MKOPadding,
    NKOPadding,
    MNKOPadding,
};

inline std::string getGemmSpecializationString(const GemmSpecialization& s)
@@ -31,6 +41,14 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
    case GemmSpecialization::MKPadding: return "MKPadding";
    case GemmSpecialization::NKPadding: return "NKPadding";
    case GemmSpecialization::MNKPadding: return "MNKPadding";
    case GemmSpecialization::OPadding: return "OPadding";
    case GemmSpecialization::MOPadding: return "MOPadding";
    case GemmSpecialization::NOPadding: return "NOPadding";
    case GemmSpecialization::KOPadding: return "KOPadding";
    case GemmSpecialization::MNOPadding: return "MNOPadding";
    case GemmSpecialization::MKOPadding: return "MKOPadding";
    case GemmSpecialization::NKOPadding: return "NKOPadding";
    case GemmSpecialization::MNKOPadding: return "MNKOPadding";
    default: return "Unrecognized specialization!";
    }
}
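The added entries extend the Gemm specializations with an O dimension for the fused Gemm + Gemm kernels (O is the second GEMM's Gemm1N extent; compare the B1[O, N] and C[M, O] descriptors in matrix_padder.hpp below). A minimal standalone sketch of the string mapping, re-declaring a reduced copy of the enum locally instead of including the CK header:

// Standalone illustration (not the CK header itself): a reduced copy of the new
// Gemm + Gemm specializations and the same string mapping shown in the hunk above.
#include <iostream>
#include <string>

enum struct GemmSpecialization { OPadding, MOPadding, NOPadding, KOPadding,
                                 MNOPadding, MKOPadding, NKOPadding, MNKOPadding };

std::string getGemmSpecializationString(const GemmSpecialization& s)
{
    switch(s)
    {
    case GemmSpecialization::OPadding: return "OPadding";
    case GemmSpecialization::MOPadding: return "MOPadding";
    case GemmSpecialization::NOPadding: return "NOPadding";
    case GemmSpecialization::KOPadding: return "KOPadding";
    case GemmSpecialization::MNOPadding: return "MNOPadding";
    case GemmSpecialization::MKOPadding: return "MKOPadding";
    case GemmSpecialization::NKOPadding: return "NKOPadding";
    case GemmSpecialization::MNKOPadding: return "MNKOPadding";
    default: return "Unrecognized specialization!";
    }
}

int main()
{
    // M, N, K and the second GEMM's O dimension are all padded.
    std::cout << getGemmSpecializationString(GemmSpecialization::MNKOPadding) << '\n';
    return 0;
}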
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
new file (mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataType, typename AccDataType, typename OutDataType,
          typename InElementwiseOp, typename AccElementwiseOp,
          index_t Rank, index_t NumReduceDim, index_t BlockSize,
          index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t InSrcVectorDim, index_t InSrcVectorSize, index_t OutDstVectorSize>
struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, AccDataType, OutDataType,
                                                InElementwiseOp, AccElementwiseOp, Rank>
{
    static constexpr index_t kRank         = Rank;
    static constexpr index_t kNumReduceDim = NumReduceDim;

    virtual index_t GetRank() const override { return kRank; }

    virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }

    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
    using Reduction = DeviceReduceMultiBlock<InDataType, AccDataType, OutDataType, Rank,
                                             NumReduceDim, reduce::Add,
                                             InElementwiseOp, AccElementwiseOp,
                                             InMemoryDataOperationEnum::Set,
                                             false, // PropagateNan
                                             false, // OutputIndex
                                             false, // HaveIndexInputIfOutputIndex
                                             BlockSize,
                                             MThreadClusterSize, KThreadClusterSize,
                                             MThreadSliceSize, KThreadSliceSize,
                                             InSrcVectorDim, InSrcVectorSize,
                                             1>; // OutDstVectorSize

    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));

    using GridwiseSoftmaxGeneric =
        GridwiseSoftmax_mk_to_mk<InDataType, OutDataType, AccDataType, GridDesc_M_K, BlockSize,
                                 MThreadClusterSize, KThreadClusterSize,
                                 MThreadSliceSize, KThreadSliceSize,
                                 InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, false>;

    using GridwiseSoftmaxSweepOnce =
        GridwiseSoftmax_mk_to_mk<InDataType, OutDataType, AccDataType, GridDesc_M_K, BlockSize,
                                 MThreadClusterSize, KThreadClusterSize,
                                 MThreadSliceSize, KThreadSliceSize,
                                 InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, true>;

    struct Argument : public Reduction::Argument
    {
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> inStrides,
                 const std::vector<index_t> reduceDims,
                 AccDataType alpha,
                 AccDataType beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 InElementwiseOp in_elementwise_op,
                 AccElementwiseOp acc_elementwise_op)
            : Reduction::Argument(inLengths,
                                  inStrides,
                                  {},
                                  {},
                                  reduceDims,
                                  0.0f, // alpha
                                  0.0f, // beta
                                  in_dev,
                                  nullptr,
                                  out_dev,
                                  nullptr,
                                  in_elementwise_op,
                                  acc_elementwise_op),
              // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
              // float32 precision. Make it support any data type so the fields can be removed.
              alpha_(alpha),
              beta_(beta)
        {
            // std::cout << "blkGroupSize= " << this->blkGroupSize
            //           << ", numBlockTileIteration= " << this->numBlockTileIteration
            //           << ", gridSize=" << this->gridSize
            //           << ", invariant_total_length=" << this->invariant_total_length <<
            //           std::endl;
        }

        AccDataType alpha_;
        AccDataType beta_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
            const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

            bool sweep_once =
                in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;

            const auto kernel_main =
                sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce, InDataType, OutDataType,
                                            AccDataType, GridDesc_M_K>
                           : kernel_softmax<GridwiseSoftmaxGeneric, InDataType, OutDataType,
                                            AccDataType, GridDesc_M_K>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               in_grid_desc_m_k,
                                               out_grid_desc_m_k,
                                               arg.blkGroupSize,
                                               arg.numBlockTileIteration,
                                               arg.alpha_,
                                               arg.in_dev_,
                                               arg.beta_,
                                               arg.out_dev_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        if(!Reduction::IsSupportedArgument(p_arg_))
        {
            return false;
        }

        if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
        {
            return false;
        }

        return true;
    };

    //
    // @brief      Makes a pointer to Argument class.
    //
    // @param[in]  inLengths           Input tensor extent(s) from high to low dimension
    // @param[in]  inStrides           Input tensor stride(s) from high to low dimension
    // @param[in]  reduceDims          The dimension(s) the normalization operation is applied
    // @param[in]  alpha               Typeless pointer in host memory storing the alpha scaling
    //                                 value as type AccDataType
    // @param[in]  beta                Typeless pointer in host memory storing the beta scaling
    //                                 value as type AccDataType
    // @param[in]  in_dev              Typeless const pointer in device memory storing the input
    //                                 tensor
    // @param      out_dev             Typeless pointer in device memory storing the output tensor
    // @param[in]  in_elementwise_op   The input elementwise operation.
    // @param[in]  acc_elementwise_op  The accumulation elementwise operation.
    //
    // @return     Unique pointer to the Argument class.
    //
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> inStrides,
                                                      const std::vector<int> reduceDims,
                                                      const void* alpha,
                                                      const void* beta,
                                                      const void* in_dev,
                                                      void* out_dev,
                                                      InElementwiseOp in_elementwise_op,
                                                      AccElementwiseOp acc_elementwise_op) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          reduceDims,
                                          *static_cast<const AccDataType*>(alpha),
                                          *static_cast<const AccDataType*>(beta),
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          in_elementwise_op,
                                          acc_elementwise_op);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceReduceSoftmax<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim
            << "_InSrcVectorSize_" << InSrcVectorSize
            << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
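A minimal host-side sketch of how this device op is typically driven through the pointer-based interface declared above (MakeArgumentPointer / IsSupportedArgument / MakeInvokerPointer). The concrete template parameter values, tensor sizes and the use of PassThrough for both elementwise ops are illustrative assumptions, not taken from this commit; it assumes the CK headers and a HIP runtime are available.

// Hedged sketch: instantiates DeviceSoftmaxImpl with example tuning parameters and runs it.
#include <hip/hip_runtime.h>
#include <vector>
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"

namespace ew = ck::tensor_operation::element_wise;

// Parameter order follows the template declared above; the numeric values are assumptions.
using Softmax = ck::tensor_operation::device::DeviceSoftmaxImpl<
    ck::half_t, float, ck::half_t,    // InDataType, AccDataType, OutDataType
    ew::PassThrough, ew::PassThrough, // InElementwiseOp, AccElementwiseOp
    2, 1,                             // Rank, NumReduceDim (softmax over the last dimension)
    256, 8, 32, 1, 8,                 // BlockSize, M/K thread cluster, M/K thread slice
    1, 8, 8>;                         // InSrcVectorDim, InSrcVectorSize, OutDstVectorSize

int main()
{
    std::vector<ck::index_t> lengths{1024, 2048};
    std::vector<ck::index_t> strides{2048, 1};

    ck::half_t *in_dev, *out_dev;
    hipMalloc(&in_dev, 1024 * 2048 * sizeof(ck::half_t));
    hipMalloc(&out_dev, 1024 * 2048 * sizeof(ck::half_t));

    // alpha/beta are passed by address and read as AccDataType (see MakeArgumentPointer doc).
    float alpha = 1.f, beta = 0.f;

    Softmax device_op;
    auto arg = device_op.MakeArgumentPointer(lengths, strides, {1}, &alpha, &beta,
                                             in_dev, out_dev,
                                             ew::PassThrough{}, ew::PassThrough{});

    if(device_op.IsSupportedArgument(arg.get()))
    {
        device_op.MakeInvokerPointer()->Run(arg.get());
    }

    hipFree(in_dev);
    hipFree(out_dev);
    return 0;
}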
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
@@ -12,12 +12,220 @@ namespace ck {
namespace tensor_operation {
namespace device {

template <typename TensorDesc,
          typename TileLengths, // Tuple<...>
          typename DoPads>      // Sequence<bool, bool, ...>
__host__ __device__ constexpr auto
PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
{
    constexpr index_t num_dim = DoPads::Size();

    static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
                  "wrong! inconsistent # of dimensions");

    // transforms
    const auto transforms = generate_tuple(
        [&](auto idim) {
            const auto MRaw     = desc.GetLength(idim);
            const auto MPerTile = tile_lengths[idim];
            const auto M        = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
            const auto MPad     = M - MRaw;

            const bool DoPadM = DoPads::At(idim);

            const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
                                                             make_pass_through_transform(MRaw));

            return MTransform;
        },
        Number<num_dim>{});

    // lower dimension Id
    const auto lower_dimss =
        generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});

    // upper dimension Id
    const auto upper_dimss = lower_dimss;

    return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
}

// M/N/K/OPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec, typename MPerTileType, typename NPerTileType,
          typename KPerTileType, typename OPerTileType>
struct GemmGemmPadder
{
    // TODO: hard to scale; use mask instead
    static constexpr bool PadM =
        GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
        GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
        GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
        GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
    static constexpr bool PadN =
        GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
        GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
        GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
        GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
    static constexpr bool PadK =
        GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
        GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
        GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
        GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
    static constexpr bool PadO =
        GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding ||
        GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding ||
        GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
        GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;

    // A[M, K]
    template <typename ADesc_MRaw_KRaw>
    __host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
    {
        return PadTensorDescriptor(
            a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
    }

    // B[K, N]
    template <typename BDesc_NRaw_KRaw>
    __host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
    {
        return PadTensorDescriptor(
            b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
    }

    // B1[Gemm1N, Gemm1K] = B1[O, N]
    template <typename B1Desc_NRaw_KRaw>
    __host__ __device__ constexpr auto PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
    {
        return PadTensorDescriptor(
            b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
    }

    // C[M, Gemm1N] = C[M, O]
    template <typename CDesc_MRaw_NRaw>
    __host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
    {
        return PadTensorDescriptor(
            c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
    }

    MPerTileType MPerTile_;
    NPerTileType NPerTile_;
    KPerTileType KPerTile_;
    OPerTileType OPerTile_;
};

// M/N/KPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec, typename MPerTileType, typename NPerTileType,
          typename KPerTileType>
-struct MatrixPadder
+struct GemmPadder
{
    static constexpr bool PadM =
        (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
         GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
    static constexpr bool PadN =
        (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
         GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
    static constexpr bool PadK =
        (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
         GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);

    template <typename ADesc_MRaw_KRaw>
    __host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
    {
        return PadTensorDescriptor(
            a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
    }

    template <typename BDesc_NRaw_KRaw>
    __host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
    {
        return PadTensorDescriptor(
            b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
    }

    template <typename CDesc_MRaw_NRaw>
    __host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
    {
        return PadTensorDescriptor(
            c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
    }

    MPerTileType MPerTile_;
    NPerTileType NPerTile_;
    KPerTileType KPerTile_;
};

// Alias of GemmPadder; to deprecate
template <GemmSpecialization GemmSpec, typename MPerTileType, typename NPerTileType,
          typename KPerTileType>
struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
{
};

// M/N/KPerTileType could be index_t or Number<>
template <bool PadM, bool PadN, bool PadK,
          typename MPerTileType, typename NPerTileType, typename KPerTileType>
struct GemmPadder_v2
{
    template <typename ADesc_MRaw_KRaw>
    __host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
    {
        return PadTensorDescriptor(
            a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
    }

    template <typename BDesc_NRaw_KRaw>
    __host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
    {
        return PadTensorDescriptor(
            b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
    }

    template <typename CDesc_MRaw_NRaw>
    __host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
    {
        return PadTensorDescriptor(
            c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
    }

    MPerTileType MPerTile_;
    NPerTileType NPerTile_;
    KPerTileType KPerTile_;
};

// M/N/KPerTileType could be index_t or Number<>
template <bool PadM, bool PadN, bool PadK,
          typename MPerTileType, typename NPerTileType, typename KPerTileType>
struct MatrixPadder_v2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

@@ -37,8 +245,7 @@ struct MatrixPadder
        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

-       if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
-                    GemmSpec == GemmSpecialization::MNKPadding)
+       if constexpr(PadM && PadK)
        {
            // pad both M and K
            return transform_tensor_descriptor(
                a_desc_mraw_kraw,
@@ -47,8 +254,7 @@ struct MatrixPadder
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
-       else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                         GemmSpec == GemmSpecialization::MNPadding)
+       else if constexpr(PadM && (!PadK))
        {
            // pad M, but not K
            return transform_tensor_descriptor(
@@ -57,8 +263,7 @@ struct MatrixPadder
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
-       else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                         GemmSpec == GemmSpecialization::NKPadding)
+       else if constexpr((!PadM) && PadK)
        {
            // pad K, but not M
            return transform_tensor_descriptor(
@@ -87,8 +292,7 @@ struct MatrixPadder
        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

-       if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
-                    GemmSpec == GemmSpecialization::MNKPadding)
+       if constexpr(PadN && PadK)
        {
            // pad both N and K
            return transform_tensor_descriptor(
                b_desc_nraw_kraw,
@@ -97,8 +301,7 @@ struct MatrixPadder
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
-       else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                         GemmSpec == GemmSpecialization::MNPadding)
+       else if constexpr(PadN && (!PadK))
        {
            // pad N, but not K
            return transform_tensor_descriptor(
@@ -107,8 +310,7 @@ struct MatrixPadder
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
-       else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                         GemmSpec == GemmSpecialization::MKPadding)
+       else if constexpr((!PadN) && PadK)
        {
            // pad K, but not N
            return transform_tensor_descriptor(
@@ -137,8 +339,7 @@ struct MatrixPadder
        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

-       if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
-                    GemmSpec == GemmSpecialization::MNKPadding)
+       if constexpr(PadM && PadN)
        {
            // pad M and N
            return transform_tensor_descriptor(
                c_desc_mraw_nraw,
@@ -147,8 +348,7 @@ struct MatrixPadder
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
-       else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                         GemmSpec == GemmSpecialization::MKPadding)
+       else if constexpr(PadM && (!PadN))
        {
            // pad M, but not N
            return transform_tensor_descriptor(
@@ -157,8 +357,7 @@ struct MatrixPadder
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
-       else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                         GemmSpec == GemmSpecialization::NKPadding)
+       else if constexpr((!PadM) && PadN)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
@@ -178,7 +377,6 @@ struct MatrixPadder
    NPerTileType NPerTile_;
    KPerTileType KPerTile_;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
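For orientation, a standalone illustration of the pad-to-tile arithmetic PadTensorDescriptor applies per dimension (round the raw extent up to a multiple of the tile length, right-pad by the difference). The names below are local to this sketch, not CK symbols:

// Standalone sketch of the padding arithmetic used above.
#include <cstdio>

int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    const int MRaw     = 1000, MPerTile = 256;                        // e.g. a GEMM M extent and MPerBlock
    const int M        = integer_divide_ceil(MRaw, MPerTile) * MPerTile; // 1024
    const int MPad     = M - MRaw;                                       // 24

    // With DoPads::At(idim) == true, PadTensorDescriptor applies
    // make_right_pad_transform(MRaw, MPad); otherwise the dimension passes through unchanged.
    std::printf("MRaw=%d -> padded M=%d (MPad=%d)\n", MRaw, M, MPad);
    return 0;
}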
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
@@ -93,7 +93,7 @@ struct GNDHWC : public BaseTensorLayout
};

// input tensor
-// packed GNWC/GNHWC/GNDHWC
+// packed NWGC/NHWGC/NDHWGC
struct NWGC : public BaseTensorLayout
{
    static constexpr const char* name = "NWGC";
@@ -330,6 +330,54 @@ struct G_NDHW_K : public BaseTensorLayout
    static constexpr const char* name = "G_NDHW_K";
};

// K-reduced output tensor (packed)
struct GNW : public BaseTensorLayout
{
    static constexpr const char* name = "GNW";
};

struct GNHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNHW";
};

struct GNDHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNDHW";
};

// K-reduced output tensor (packed)
struct NWG : public BaseTensorLayout
{
    static constexpr const char* name = "NWG";
};

struct NHWG : public BaseTensorLayout
{
    static constexpr const char* name = "NHWG";
};

struct NDHWG : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWG";
};

// K-reduced output tensor (strided)
struct G_NW : public BaseTensorLayout
{
    static constexpr const char* name = "G_NW";
};

struct G_NHW : public BaseTensorLayout
{
    static constexpr const char* name = "G_NHW";
};

struct G_NDHW : public BaseTensorLayout
{
    static constexpr const char* name = "G_NDHW";
};

} // namespace convolution

template <
...
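A small standalone sketch of how such layout tags are generally consumed: the tag type carries no data, only a compile-time name used for logging and instance naming. BaseTensorLayout and the two tags are re-declared locally here purely for illustration:

// Standalone illustration of the layout-tag pattern used in tensor_layout.hpp.
#include <iostream>

struct BaseTensorLayout {};
struct GNHW : public BaseTensorLayout { static constexpr const char* name = "GNHW"; };
struct NHWG : public BaseTensorLayout { static constexpr const char* name = "NHWG"; };

template <typename ReducedOutLayout>
void describe()
{
    std::cout << "K-reduced output layout: " << ReducedOutLayout::name << '\n';
}

int main()
{
    describe<GNHW>(); // packed layout with G as the outermost dimension
    describe<NHWG>(); // packed layout with G as the innermost dimension
    return 0;
}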
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
@@ -28,6 +28,13 @@ struct Add
        y = x0 + x1;
    };

    template <>
    __host__ __device__ constexpr void operator()<float>(float& y,
                                                         const float& x0,
                                                         const half_t& x1) const
    {
        y = x0 + type_convert<half_t>(x1);
    };

    template <>
    __host__ __device__ constexpr void operator()<half_t>(half_t& y,
                                                          const float& x0,
                                                          const half_t& x1) const
@@ -172,6 +179,14 @@ struct AddRelu
        const float a = x0 + x1;
        y             = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
    };

    template <>
    __host__ __device__ constexpr void operator()<float, float, half_t>(float& y,
                                                                        const float& x0,
                                                                        const half_t& x1) const
    {
        const float a = x0 + type_convert<float>(x1);
        y             = a > 0.0f ? a : 0.0f;
    };
};

struct AddHardswish
@@ -210,6 +225,46 @@ struct AddHardswish
    };
};

// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
    // Fast GeLU
    // https://paperswithcode.com/method/gelu
    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
    __host__ __device__ static constexpr float GetFastGeLU(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        return x * cdf;
    }

    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;

    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
    {
        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
                      is_valid_param_type_v<D>);

        const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));

        e = type_convert<E>(y);
    }

    template <typename D>
    __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
    {
        static_assert(is_valid_param_type_v<D>);

        e = GetFastGeLU(c + type_convert<float>(d));
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
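As a sanity check of GetFastGeLU above: the exp-based expression is algebraically the tanh form quoted in the comment, since tanh(u/2) = 2/(1+e^-u) - 1 and 0.797885 * 0.044715 is approximately 0.035677. A standalone comparison (plain C++, no CK dependencies):

// Compares the commit's exp-based fast-GeLU arithmetic with the tanh reference formula.
#include <cmath>
#include <cstdio>

float GetFastGeLU(float x) // same arithmetic as AddFastGelu::GetFastGeLU above
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

float TanhFastGeLU(float x) // reference: y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
{
    return 0.5f * x * (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

int main()
{
    for(float x : {-3.f, -1.f, 0.f, 0.5f, 2.f})
        std::printf("x=% .2f  fast=% .6f  tanh=% .6f\n", x, GetFastGeLU(x), TanhFastGeLU(x));
    return 0;
}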
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
@@ -98,6 +98,18 @@ struct AddReluAdd
        int32_t c = b + x2;
        y         = c;
    }

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    __host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
        int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
    {
        int32_t a = x0 + x1;
        int32_t b = a > 0 ? a : 0;
        int32_t c = b + x2;
        y         = c;
    }
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
};

struct AddHardswishAdd
@@ -177,7 +189,11 @@ struct AddAddFastGelu
    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-       std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
+       std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+       || std::is_same_v<T, ck::int4_t>
+#endif
+       ;

    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void
@@ -198,17 +214,44 @@ struct Normalize
    // FIXME: is double absolutely necessary?
    Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}

-   template <typename T>
-   __host__ __device__ constexpr void operator()(T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
+   template <typename T1, typename T2, typename T3>
+   __host__ __device__ constexpr void operator()(T1& y, const T1& x, const T2& mean, const T2& mean_square, const T3& gamma, const T3& beta) const;
+
+   template <>
+   __host__ __device__ constexpr void operator()<half_t, float, half_t>(half_t& y, const half_t& x, const float& mean, const float& mean_square, const half_t& gamma, const half_t& beta) const
+   {
+       using ck::math::sqrt;
+
+       float variance = mean_square - (mean * mean);
+
+       float tmp_x     = type_convert<float>(x);
+       float tmp_gamma = type_convert<float>(gamma);
+       float tmp_beta  = type_convert<float>(beta);
+
+       float tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * tmp_gamma + tmp_beta;
+
+       y = type_convert<half_t>(tmp_y);
+   };

    template <>
-   __host__ __device__ constexpr void operator()<float>(float& y, const float& x, const float& mean, const float& mean_square, const float& gamma, const float& beta) const
+   __host__ __device__ constexpr void operator()<float, float, float>(float& y, const float& x, const float& mean, const float& mean_square, const float& gamma, const float& beta) const
    {
        using ck::math::sqrt;
@@ -217,12 +260,12 @@ struct Normalize
    };

    template <>
-   __host__ __device__ constexpr void operator()<double>(double& y, const double& x, const double& mean, const double& mean_square, const double& gamma, const double& beta) const
+   __host__ __device__ constexpr void operator()<double, double, double>(double& y, const double& x, const double& mean, const double& mean_square, const double& gamma, const double& beta) const
    {
        using ck::math::sqrt;
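A standalone sketch of the arithmetic the Normalize specializations perform: the variance is recovered from the running mean and mean-of-squares, then the input is normalized and affinely transformed (plain float here; the half_t specialization above additionally converts operands to float before computing):

// Standalone illustration of Normalize's per-element formula.
#include <cmath>
#include <cstdio>

float normalize(float x, float mean, float mean_square, float gamma, float beta,
                float epsilon = 1e-4f)
{
    const float variance = mean_square - mean * mean;            // E[x^2] - E[x]^2
    return (x - mean) / std::sqrt(variance + epsilon) * gamma + beta;
}

int main()
{
    // Toy statistics: mean 2 and E[x^2] = 5, i.e. variance 1.
    std::printf("%f\n", normalize(/*x=*/3.f, /*mean=*/2.f, /*mean_square=*/5.f,
                                  /*gamma=*/1.f, /*beta=*/0.f)); // ~0.99995
    return 0;
}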
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -62,6 +62,14 @@ struct PassThrough
    {
        y = type_convert<int8_t>(x);
    }

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
    {
        y = x;
    }
#endif
};

struct UnaryConvert
@@ -89,6 +97,22 @@ struct Scale
    float scale_;
};

struct ScaleAndResetNaNToMinusInfinity
{
    __host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const;

    template <>
    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
    {
        y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
    };

    float scale_;
};

struct UnaryDivide
{
    __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
@@ -111,9 +135,13 @@ struct UnarySquare
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
-       static_assert(is_same<T, float>::value || is_same<T, double>::value,
+       static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
+                     is_same_v<T, int8_t>
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+                     || is_same_v<T, int4_t>
+#endif
+                     ,
                      "Data type is not supported by this operation!");
        y = x * x;
    };
};
@@ -183,6 +211,27 @@ struct FastGelu
    }
};

// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const;

    template <>
    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
    {
        y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
    }

    template <>
    __host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y, const ck::half_t& x) const
    {
        y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
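Standalone sketches of the two functors added above, the erf-based GeLU and the scale that resets NaN inputs to minus infinity (plain C++ stand-ins, no CK dependencies):

// Reference implementations of Gelu and ScaleAndResetNaNToMinusInfinity's arithmetic.
#include <cmath>
#include <cstdio>
#include <limits>

float gelu(float x) // y = 0.5*x*(1+erf(x/sqrt(2)))
{
    return 0.5f * x * (1.f + std::erf(0.70710678118f * x));
}

float scale_and_reset_nan(float x, float scale)
{
    return std::isnan(x) ? -std::numeric_limits<float>::infinity() : scale * x;
}

int main()
{
    std::printf("gelu(1.0)  = %f\n", gelu(1.0f));                      // ~0.841345
    std::printf("masked NaN = %f\n", scale_and_reset_nan(NAN, 0.5f));  // -inf
    std::printf("scaled 2.0 = %f\n", scale_and_reset_nan(2.0f, 0.5f)); // 1.0
    return 0;
}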
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
new file (mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseMultipleReduction, index_t NumReduction, typename InDataType,
          typename OutDataTypePointerTuple, typename AccDataType, typename InGridDesc_M_K,
          typename OutGridDesc_M_Tuple, typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple>
__global__ void
kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                  const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
                                  const InElementwiseOperationTuple in_elementwise_op_tuple,
                                  const AccElementwiseOperationTuple acc_elementwise_op_tuple,
                                  index_t block_group_size,
                                  index_t num_k_block_tile_iteration,
                                  Array<AccDataType, NumReduction> alpha_values,
                                  const InDataType* const __restrict__ p_in_value_global,
                                  Array<AccDataType, NumReduction> beta_values,
                                  OutDataTypePointerTuple p_out_value_global_tuple)
{
    GridwiseMultipleReduction::Run(in_grid_desc_m_k,
                                   out_grid_desc_m_tuple,
                                   in_elementwise_op_tuple,
                                   acc_elementwise_op_tuple,
                                   block_group_size,
                                   num_k_block_tile_iteration,
                                   alpha_values,
                                   p_in_value_global,
                                   beta_values,
                                   p_out_value_global_tuple);
};

template <index_t NumReduction, typename InDataType, typename OutDataTypePointerTuple,
          typename AccDataType, typename InGridDesc_M_K, typename OutGridDesc_M_Tuple,
          typename ReduceOperation, typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple, InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan, index_t BlockSize, index_t MThreadClusterSize,
          index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t InSrcVectorDim, index_t InSrcVectorSize, typename OutDstVectorSizeSeq>
struct GridwiseMultipleReduction_mk_to_m_multiblock
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
                      NumReduction == OutGridDesc_M_Tuple::Size() &&
                      NumReduction == OutDstVectorSizeSeq::Size() &&
                      NumReduction == InElementwiseOperationTuple::Size() &&
                      NumReduction == AccElementwiseOperationTuple::Size(),
                  "All tuple should have the same size as the number of Reductions!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType, BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ReduceOperation, PropagateNan>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType, ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M, ReduceOperation,
                                                 PropagateNan>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
                               const InElementwiseOperationTuple& in_elementwise_op_tuple,
                               const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               Array<AccDataType, NumReduction> alpha_values,
                               const InDataType* const __restrict__ p_in_value_global,
                               Array<AccDataType, NumReduction> beta_values,
                               OutDataTypePointerTuple p_out_value_global_tuple)
    {
        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        // LDS, reused by all reductions
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());

        auto out_global_val_buf_tuple = generate_tuple(
            [&](auto iR) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
            },
            Number<NumReduction>{});

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                    MThreadSliceSize * KThreadSliceSize, true>{};
            },
            Number<NumReduction>{});

        auto accu_value_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>{};
            },
            Number<NumReduction>{});

        static_for<0, NumReduction, 1>{}([&](auto iR) {
            static_for<0, MThreadSliceSize, 1>{}(
                [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
        });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc),
            ThreadBufferLengths, ThreadBufferDimAccessOrder, InSrcVectorDim, InSrcVectorSize, 1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, NumReduction, 1>{}([&](auto iR) {
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
                                                    in_thread_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, NumReduction, 1>{}([&](auto iR) {
            using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
            using OutDataType        = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf_tuple(iR)(I));
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if(thread_k_cluster_id == 0)
                {
                    acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
                                                 accu_value_buf_tuple(iR)(I));

                    accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
                }
            });

            if(thread_k_cluster_id == 0)
            {
                if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR]))
                {
                    StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                        OutDataType, OutDataType, decltype(out_grid_desc_m_tuple[iR]),
                        decltype(reduced_data_desc), Sequence<MThreadSliceSize>, Sequence<0>, 0,
                        OutDstVectorSizeSeq::At(iR), 1, false>(
                        out_grid_desc_m_tuple[iR],
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
                                            out_global_val_buf_tuple(iR),
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf_tuple(iR)(I) +=
                            type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
                    });
                };

                auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<
                    AccDataType, OutDataType, decltype(reduced_data_desc),
                    decltype(out_grid_desc_m_tuple[iR]), PassThroughOp, Sequence<MThreadSliceSize>,
                    Sequence<0>, 0, OutDstVectorSizeSeq::At(iR), OutMemoryDataOperation, 1, true>(
                    out_grid_desc_m_tuple[iR],
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

                threadwise_dst_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf_tuple[iR],
                                         out_grid_desc_m_tuple[iR],
                                         out_global_val_buf_tuple(iR));
            };
        });
    };
};

} // namespace ck
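A CPU-side sketch of the per-row update each reduction in the tuple performs while sharing a single pass over the input: apply that reduction's input elementwise op, reduce over K, then scale by alpha and blend with the prior output by beta (the acc elementwise op is omitted here, i.e. taken as identity). The std:: containers and lambdas stand in for the CK tuple machinery and are not CK symbols:

// Host-side analogue of the multi-reduction update, for two reductions over one input.
#include <cstdio>
#include <functional>
#include <vector>

int main()
{
    const int M = 2, K = 4;
    std::vector<std::vector<float>> x = {{1, 2, 3, 4}, {5, 6, 7, 8}};

    // Two "reductions" over the same input: plain sum and sum of squares.
    auto identity = [](float v) { return v; };
    auto square   = [](float v) { return v * v; };
    std::vector<std::function<float(float)>> in_op = {identity, square};
    std::vector<float> alpha = {1.f, 0.5f}, beta = {0.f, 0.f};

    std::vector<std::vector<float>> out(2, std::vector<float>(M, 0.f));
    for(int r = 0; r < 2; ++r)
        for(int m = 0; m < M; ++m)
        {
            float acc = 0.f;                      // identity value of reduce::Add
            for(int k = 0; k < K; ++k)
                acc += in_op[r](x[m][k]);         // elementwise pre-op, then reduce over K
            out[r][m] = acc * alpha[r] + beta[r] * out[r][m];
        }

    std::printf("sum row0 = %g, 0.5*sumsq row0 = %g\n", out[0][0], out[1][0]); // 10 and 15
    return 0;
}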
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp
new file (mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseMultipleReduction, index_t NumReduction, typename InDataType,
          typename OutDataTypePointerTuple, typename AccDataType, typename InGridDesc_M_K,
          typename OutGridDesc_M_Tuple, typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple>
__global__ void
kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                  const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
                                  const InElementwiseOperationTuple in_elementwise_op_tuple,
                                  const AccElementwiseOperationTuple acc_elementwise_op_tuple,
                                  Array<AccDataType, NumReduction> alpha_values,
                                  const InDataType* const __restrict__ p_in_value_global,
                                  Array<AccDataType, NumReduction> beta_values,
                                  OutDataTypePointerTuple p_out_value_global_tuple)
{
    GridwiseMultipleReduction::Run(in_grid_desc_m_k,
                                   out_grid_desc_m_tuple,
                                   in_elementwise_op_tuple,
                                   acc_elementwise_op_tuple,
                                   alpha_values,
                                   p_in_value_global,
                                   beta_values,
                                   p_out_value_global_tuple);
};

template <index_t NumReduction, typename InDataType, typename OutDataTypePointerTuple,
          typename AccDataType, typename InGridDesc_M_K, typename OutGridDesc_M_Tuple,
          typename ReduceOperation, typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple, InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan, index_t BlockSize, index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t InSrcVectorDim, index_t InSrcVectorSize, typename OutDstVectorSizeSeq>
struct GridwiseMultipleReduction_mk_to_m_threadwise
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
                      NumReduction == OutGridDesc_M_Tuple::Size() &&
                      NumReduction == OutDstVectorSizeSeq::Size() &&
                      NumReduction == InElementwiseOperationTuple::Size() &&
                      NumReduction == AccElementwiseOperationTuple::Size(),
                  "All tuple should have the same size as the number of Reductions!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType, ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M, ReduceOperation,
                                                 PropagateNan>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
                               const InElementwiseOperationTuple& in_elementwise_op_tuple,
                               const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
                               Array<AccDataType, NumReduction> alpha_values,
                               const InDataType* const __restrict__ p_in_value_global,
                               Array<AccDataType, NumReduction> beta_values,
                               OutDataTypePointerTuple p_out_value_global_tuple)
    {
        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());

        auto out_global_val_buf_tuple = generate_tuple(
            [&](auto iR) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
            },
            Number<NumReduction>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                    MThreadSliceSize * KThreadSliceSize, true>{};
            },
            Number<NumReduction>{});

        auto accu_value_buf_tuple = generate_tuple(
            [&](auto iR) {
                (void)iR;
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>{};
            },
            Number<NumReduction>{});

        static_for<0, NumReduction, 1>{}([&](auto iR) {
            static_for<0, MThreadSliceSize, 1>{}(
                [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
        });

        const index_t thread_global_1d_id = get_thread_global_1d_id();
        const auto toReduceLength         = in_grid_desc_m_k.GetLength(Number<1>{});

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc),
            ThreadBufferLengths, ThreadBufferDimAccessOrder, InSrcVectorDim, InSrcVectorSize, 1,
            false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t reducedLength = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, NumReduction, 1>{}([&](auto iR) {
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
                                                    in_thread_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, NumReduction, 1>{}([&](auto iR) {
            using OutDataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
            using OutDataType        = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
                                             accu_value_buf_tuple(iR)(I));

                accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
            });

            if(!float_equal_zero{}(beta_values[iR]))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                    OutDataType, OutDataType, decltype(out_grid_desc_m_tuple[iR]),
                    decltype(reduced_data_desc), Sequence<MThreadSliceSize>, Sequence<0>, 0,
                    OutDstVectorSizeSeq::At(iR), 1, false>(
                    out_grid_desc_m_tuple[iR],
                    make_multi_index(thread_global_1d_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
                                        out_global_val_buf_tuple(iR),
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf_tuple(iR)(I) +=
                        type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
                });
            };

            auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType, OutDataType, decltype(reduced_data_desc),
                decltype(out_grid_desc_m_tuple[iR]), PassThroughOp, Sequence<MThreadSliceSize>,
                Sequence<0>, 0, OutDstVectorSizeSeq::At(iR), OutMemoryDataOperation, 1, true>(
                out_grid_desc_m_tuple[iR],
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp{});

            threadwise_dst_store.Run(reduced_data_desc,
                                     make_tuple(I0),
                                     accu_value_buf_tuple[iR],
                                     out_grid_desc_m_tuple[iR],
                                     out_global_val_buf_tuple(iR));
        });
    };
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp
deleted
100644 → 0
View file @
41d5fca7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
template
<
typename
Gridwise5AryEltwise
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DDataType
,
typename
EDataType
,
typename
FDataType
,
typename
AGridDesc_M
,
typename
BGridDesc_M
,
typename
CGridDesc_M
,
typename
DGridDesc_M
,
typename
EGridDesc_M
,
typename
FGridDesc_M
,
typename
ElementwiseFunctor
>
__global__
void
kernel_5ary_elementwise_1d
(
const
ADataType
*
__restrict__
p_a_global
,
const
BDataType
*
__restrict__
p_b_global
,
const
CDataType
*
__restrict__
p_c_global
,
const
DDataType
*
__restrict__
p_d_global
,
const
EDataType
*
__restrict__
p_e_global
,
FDataType
*
__restrict__
p_f_global
,
const
AGridDesc_M
a_grid_desc_m
,
const
BGridDesc_M
b_grid_desc_m
,
const
CGridDesc_M
c_grid_desc_m
,
const
DGridDesc_M
d_grid_desc_m
,
const
EGridDesc_M
e_grid_desc_m
,
const
FGridDesc_M
f_grid_desc_m
,
const
ElementwiseFunctor
functor
)
{
Gridwise5AryEltwise
::
Run
(
p_a_global
,
p_b_global
,
p_c_global
,
p_d_global
,
p_e_global
,
p_f_global
,
a_grid_desc_m
,
b_grid_desc_m
,
c_grid_desc_m
,
d_grid_desc_m
,
e_grid_desc_m
,
f_grid_desc_m
,
functor
);
}
// TODO - implement n-ary Elementwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename ComputeDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               const CDataType* __restrict__ p_c_global,
                               const DDataType* __restrict__ p_d_global,
                               const EDataType* __restrict__ p_e_global,
                               FDataType* __restrict__ p_f_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const DGridDesc_M d_grid_desc_m,
                               const EGridDesc_M e_grid_desc_m,
                               const FGridDesc_M f_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());
        const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_grid_desc_m.GetElementSpaceSize());
        const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_global, e_grid_desc_m.GetElementSpaceSize());
        auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_f_global, f_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load = ThreadwiseTensorSliceTransfer_v2<
            ADataType, ComputeDataType, AGridDesc_M, decltype(thread_desc_m),
            Sequence<MPerThread>, // SliceLengths
            Sequence<0>,          // DimAccessOrder
            0,                    // SrcVectorDim
            AScalarPerVector,     // ScalarPerVector
            1,                    // SrcScalarStrideInVector
            false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load = ThreadwiseTensorSliceTransfer_v2<
            BDataType, ComputeDataType, BGridDesc_M, decltype(thread_desc_m),
            Sequence<MPerThread>, // SliceLengths
            Sequence<0>,          // DimAccessOrder
            0,                    // SrcVectorDim
            BScalarPerVector,     // ScalarPerVector
            1,                    // SrcScalarStrideInVector
            false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_load = ThreadwiseTensorSliceTransfer_v2<
            CDataType, ComputeDataType, CGridDesc_M, decltype(thread_desc_m),
            Sequence<MPerThread>, // SliceLengths
            Sequence<0>,          // DimAccessOrder
            0,                    // SrcVectorDim
            CScalarPerVector,     // ScalarPerVector
            1,                    // SrcScalarStrideInVector
            false>{c_grid_desc_m, thread_store_global_offset};

        auto d_global_load = ThreadwiseTensorSliceTransfer_v2<
            DDataType, ComputeDataType, DGridDesc_M, decltype(thread_desc_m),
            Sequence<MPerThread>, // SliceLengths
            Sequence<0>,          // DimAccessOrder
            0,                    // SrcVectorDim
            DScalarPerVector,     // ScalarPerVector
            1,                    // SrcScalarStrideInVector
            false>{d_grid_desc_m, thread_store_global_offset};

        auto e_global_load = ThreadwiseTensorSliceTransfer_v2<
            EDataType, ComputeDataType, EGridDesc_M, decltype(thread_desc_m),
            Sequence<MPerThread>, // SliceLengths
            Sequence<0>,          // DimAccessOrder
            0,                    // SrcVectorDim
            EScalarPerVector,     // ScalarPerVector
            1,                    // SrcScalarStrideInVector
            false>{e_grid_desc_m, thread_store_global_offset};

        auto f_global_write = ThreadwiseTensorSliceTransfer_v1r3<
            ComputeDataType, FDataType, decltype(thread_desc_m), FGridDesc_M, PassThrough,
            Sequence<MPerThread>, // SliceLengths
            Sequence<0>,          // DimAccessOrder
            0,                    // DstVectorDim
            FScalarPerVector,     // ScalarPerVector
            InMemoryDataOperationEnum::Set,
            1,                    // DstScalarStrideInVector
            false>{f_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = c_grid_desc_m.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = M / (loop_step);
        do
        {
            // read and process MPerThread elements
            a_global_load.Run(a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
            b_global_load.Run(b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
            c_global_load.Run(c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
            d_global_load.Run(d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
            e_global_load.Run(e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);

            static_for<0, MPerThread, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
                functor(f_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}),
                        c_thread_buf(Number<offset>{}),
                        d_thread_buf(Number<offset>{}),
                        e_thread_buf(Number<offset>{}));
            });

            f_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               f_thread_buf,
                               f_grid_desc_m,
                               f_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
            c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
            d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
            e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
            f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
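Illustrative sketch, not part of the commit: the functor contract the 5-ary gridwise struct above expects — write the result through the first reference, read the five inputs. "Sum5" is a hypothetical example operator, not part of CK.

// Minimal example of an ElementwiseFunctor usable with Gridwise5AryElementwise_1D.
struct Sum5
{
    template <typename F, typename A, typename B, typename C, typename D, typename E>
    __host__ __device__ void operator()(F& f, const A& a, const B& b, const C& c,
                                        const D& d, const E& e) const
    {
        // matches the per-element call functor(f_thread_buf(...), a_thread_buf(...), ...)
        f = a + b + c + d + e;
    }
};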
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @ b89a88b5
...
@@ -181,36 +181,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1  = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1  = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned =
            math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
                         c_block_size * sizeof(FloatCShuffle));
        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                         SharedMemTrait::b_block_space_size_aligned) *
                                        sizeof(FloatAB);
        const index_t gemm1_bytes_end =
            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
            sizeof(FloatAB);
        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

        return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
@@ -261,8 +241,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            return false;
        }

        assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);

        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
        {
            return false;
...
@@ -312,6 +290,36 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
        static constexpr auto a_block_desc_ak0_m_ak1  = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        static constexpr auto b_block_desc_bk0_n_bk1  = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);

        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        static constexpr auto c_block_space_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };

    template <bool HasMainKBlockLoop, typename Block2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
...
@@ -358,9 +366,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...
@@ -464,14 +469,12 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
            a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
...
@@ -588,12 +591,20 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        // reuse LDS space for gemm0's b_block_buf
        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr index_t Gemm1KPack = math::max(
            math::lcm(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
        //
        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        constexpr index_t Gemm1KPack =
            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;

        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<BlockSize,
...
@@ -611,10 +622,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            MXdlPerWave,
            Gemm1NXdlPerWave,
            Gemm1KPack,
            false,
            false,      // TransposeC
            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
                make_tuple(0, 0, 0, 0)}; // TransposeC
            // BMmaKStride
                make_tuple(0, 0, 0, 0)}; // A_origin

        auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
...
@@ -699,6 +711,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                                  a1_thread_desc_k0_m_k1,
                                  make_tuple(I0, I0, I0),
                                  a1_thread_buf);

                block_sync_lds();

                gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
...
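Illustrative sketch, not part of the commit: the LDS budget arithmetic that SharedMemTrait expresses, with made-up element counts. The B1 tile of the second GEMM reuses LDS already claimed during the gemm0 phase (modelled here with offset 0, matching b1_block_space_offset above), and the final requirement is the maximum over the phases because they never need the space at the same time.

#include <algorithm>
#include <cstddef>

constexpr std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems,
                                std::size_t b1_elems, std::size_t c_shuffle_elems,
                                std::size_t ab_elem_size, std::size_t c_elem_size)
{
    const std::size_t gemm0_end = (a_elems + b_elems) * ab_elem_size;       // A tile + B tile
    const std::size_t gemm1_end = (/*b1 offset*/ 0 + b1_elems) * ab_elem_size; // B1 reuses gemm0's LDS
    const std::size_t c_end     = c_shuffle_elems * c_elem_size;            // C-shuffle staging
    return std::max({gemm0_end, gemm1_end, c_end});
}

// Example with hypothetical tile sizes: fp16 A/B/B1 tiles of 4096 elements, fp32 C-shuffle of 8192.
static_assert(lds_bytes(4096, 4096, 4096, 8192, 2, 4) == 32768, "example: 32 KiB of LDS");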
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp 0 → 100644
View file @ b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same datatype
          typename Acc0DataType,
          typename D0sDataType,
          typename Acc1DataType,
          typename C1ShuffleDataType,
          typename D1sDataType,
          typename E1DataType,
          typename A0ElementwiseOperation,
          typename B0ElementwiseOperation,
          typename CDE0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CDE1ElementwiseOperation,
          InMemoryDataOperationEnum E1GlobalMemoryDataOperation,
          typename A0GridDesc_M_K,
          typename B0GridDesc_N_K,
          typename D0sGridDesc_M_N,
          typename B1GridDesc_N_K,
          typename D1sGridDesc_M_N,
          typename E1GridDesc_M_N,
          index_t NumGemm0KPrefetchStage,
          index_t BlockSize,
          index_t Gemm0MPerBlock,
          index_t Gemm0NPerBlock,
          index_t Gemm0KPerBlock,
          index_t Gemm1NPerBlock,
          index_t Gemm1KPerBlock,
          index_t A0K1Value,
          index_t B0K1Value,
          index_t B1K1Value,
          index_t Gemm0MPerXdl,
          index_t Gemm0NPerXdl,
          index_t Gemm0MXdlPerWave,
          index_t Gemm0NXdlPerWave,
          index_t Gemm1NXdlPerWave,
          typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
          typename A0BlockTransferThreadClusterArrangeOrder,
          typename A0BlockTransferSrcAccessOrder,
          index_t A0BlockTransferSrcVectorDim,
          index_t A0BlockTransferSrcScalarPerVector,
          index_t A0BlockTransferDstScalarPerVector_AK1,
          bool A0ThreadTransferSrcResetCoordinateAfterRun, // ignored
          index_t A0BlockLdsExtraM,
          typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B0BlockTransferThreadClusterArrangeOrder,
          typename B0BlockTransferSrcAccessOrder,
          index_t B0BlockTransferSrcVectorDim,
          index_t B0BlockTransferSrcScalarPerVector,
          index_t B0BlockTransferDstScalarPerVector_BK1,
          bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
          index_t B0BlockLdsExtraN,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          index_t B1BlockTransferSrcVectorDim,
          index_t B1BlockTransferSrcScalarPerVector,
          index_t B1BlockTransferDstScalarPerVector_BK1,
          bool B1ThreadTransferSrcResetCoordinateAfterRun,
          index_t B1BlockLdsExtraN,
          index_t C1ShuffleGemm0MXdlPerWavePerShuffle,
          index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
          typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched>
struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");

    static constexpr index_t NumD0Tensor = D0sDataType::Size();
    static constexpr index_t NumD1Tensor = D1sDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto WaveSize = 64;

    // K1 should be Number<...>
    // Gemm0
    static constexpr auto A0K1 = Number<A0K1Value>{};
    static constexpr auto B0K1 = Number<B0K1Value>{};

    static constexpr auto A0K0PerBlock = Number<Gemm0KPerBlock / A0K1Value>{};
    static constexpr auto B0K0PerBlock = Number<Gemm0KPerBlock / B0K1Value>{};

    static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave);
    static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave);

    // Gemm1
    static constexpr auto B1K1         = Number<B1K1Value>{};
    static constexpr auto B1K0PerBlock = Number<Gemm1KPerBlock / B1K1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
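Illustrative sketch, not part of the commit: a plain-arithmetic equivalent of the adaptor-based GetGemm0WaveIdx() / GetGemm0WaveMNIdx() helpers that follow, assuming the 64-lane WaveSize above. The single-stage merge adaptors perform the same row-major decomposition of a thread id into (wave-M, wave-N, lane) and of a lane into its position inside an XDL tile; the struct and function names here are assumptions for clarity.

struct WaveIdx
{
    int m_wave;
    int n_wave;
    int lane;
};

constexpr WaveIdx decompose_thread_id(int thread_id, int n_waves, int wave_size = 64)
{
    const int lane    = thread_id % wave_size; // innermost merged dimension: WaveSize
    const int wave_id = thread_id / wave_size;
    return {wave_id / n_waves, wave_id % n_waves, lane}; // merge(Gemm0MWaves, Gemm0NWaves, WaveSize)
}

// lane -> (row group, column) inside an MPerXdl x NPerXdl tile: merge(WaveSize / NPerXdl, NPerXdl)
constexpr int lane_group(int lane, int n_per_xdl) { return lane / n_per_xdl; }
constexpr int lane_n(int lane, int n_per_xdl) { return lane % n_per_xdl; }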
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static
constexpr
auto
MakeD0sGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
return
static_cast
<
const
D0DataType
*>
(
nullptr
);
},
Number
<
NumD0Tensor
>
{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static
constexpr
auto
MakeD1sGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
D1DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D1sDataType
>>
;
return
static_cast
<
const
D1DataType
*>
(
nullptr
);
},
Number
<
NumD1Tensor
>
{});
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
WaveSize
/
Gemm0NPerXdl
,
Gemm0NPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
template
<
typename
A0BlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
A0BlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
Gemm0MPerBlock
/
(
Gemm0MXdlPerWave
*
Gemm0MPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm0MXdlPerWave
,
MWaves
,
Gemm0MPerXdl
>
(
A0BlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
Gemm0NPerBlock
/
(
Gemm0NXdlPerWave
*
Gemm0NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm0NXdlPerWave
,
NWaves
,
Gemm0NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
template
<
typename
A0BlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
A0BlockDesc_AK0_M_AK1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm0MXdlPerWave
,
1
,
1
>
(
A0BlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
Gemm1NWaves
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
Gemm0NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm1NXdlPerWave
,
Gemm1NWaves
,
Gemm0NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
__host__
__device__
static
constexpr
auto
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A0 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
A0K0PerBlock
,
Number
<
Gemm0MPerBlock
>
{},
A0K1
),
make_tuple
(
Number
<
Gemm0MPerBlock
+
A0BlockLdsExtraM
>
{}
*
A0K1
,
A0K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B0 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
B0K0PerBlock
,
Number
<
Gemm0NPerBlock
>
{},
B0K1
),
make_tuple
(
Number
<
Gemm0NPerBlock
+
B0BlockLdsExtraN
>
{}
*
B0K1
,
B0K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
B1K0PerBlock
,
Number
<
Gemm1NPerBlock
>
{},
B1K1
),
make_tuple
(
Number
<
Gemm1NPerBlock
+
B1BlockLdsExtraN
>
{}
*
B1K1
,
B1K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
Gemm0MPerBlock
/
(
Gemm0MXdlPerWave
*
Gemm0MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
Gemm0NPerXdl
);
constexpr
auto
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
C1ShuffleGemm0MXdlPerWavePerShuffle
*
MWave
*
Gemm0MPerXdl
>
{},
I1
,
Number
<
C1ShuffleGemm0NXdlPerWavePerShuffle
*
NWave
*
Gemm0NPerXdl
>
{}));
return
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
a0_block_space_size_aligned
+
SharedMemTrait
::
b0_block_space_size_aligned
)
*
sizeof
(
A0B0B1DataType
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
A0B0B1DataType
);
const
index_t
c1_block_bytes_end
=
SharedMemTrait
::
c1_block_space_size
*
sizeof
(
C1ShuffleDataType
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
c1_block_bytes_end
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2E1TileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
A0GridDesc_M_K
&
a0_grid_desc_m_k
,
const
B0GridDesc_N_K
&
b0_grid_desc_n_k
,
const
B1GridDesc_N_K
&
b1_grid_desc_n_k
,
const
E1GridDesc_M_N
&
e1_grid_desc_m_n
,
const
Block2E1TileMap
&
block_2_e1tile_map
)
{
static_assert
((
Gemm0MPerBlock
%
(
Gemm0MPerXdl
*
Gemm0MXdlPerWave
)
==
0
)
&&
(
Gemm0NPerBlock
%
(
Gemm0NXdlPerWave
*
Gemm0NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a0_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b0_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
a0_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
Gemm1N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
if
(
!
(
M
==
e1_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
e1_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
Gemm0MPerBlock
==
0
&&
N
%
Gemm0NPerBlock
==
0
&&
K
%
Gemm0KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
// check gemm0 gridwise gemm pipeline
const
auto
num_gemm0_k_loop
=
K
/
Gemm0KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm0_k_loop
))
{
return
false
;
}
// check gemm1 gridwise gemm pipeline
if
(
!
(
Gemm0NPerBlock
%
Gemm1KPerBlock
==
0
))
{
return
false
;
}
const
auto
num_gemm1_k_inner_loop
=
Gemm0NPerBlock
/
Gemm1KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm1_k_inner_loop
))
{
return
false
;
}
if
(
!
block_2_e1tile_map
.
CheckValidity
(
e1_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
Gemm0KPerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
// A0 desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultA0GridDescriptor_AK0_M_AK1
(
const
A0GridDesc_M_K
&
a0_grid_desc_m_k
)
{
const
auto
M
=
a0_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a0_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
A0K0
=
K
/
A0K1
;
return
transform_tensor_descriptor
(
a0_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A0K0
,
A0K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B0 desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultB0GridDescriptor_BK0_N_BK1
(
const
B0GridDesc_N_K
&
b0_grid_desc_n_k
)
{
const
auto
N
=
b0_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b0_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B0K0
=
K
/
B0K1
;
return
transform_tensor_descriptor
(
b0_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B0K0
,
B0K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// D0 desc for source in blockwise copy
template
<
typename
D0GridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
{
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
Gemm0MPerBlock
,
Gemm0MXdlPerWave
,
Gemm0MWaves
,
Gemm0MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
Gemm0NPerBlock
,
Gemm0NXdlPerWave
,
Gemm0NWaves
,
N3
,
WaveSize
/
Gemm0NPerXdl
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
// B1 desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultB1GridDescriptor_BK0_N_BK1
(
const
B1GridDesc_N_K
&
b1_grid_desc_n_k
)
{
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// C1 desc for destination in blockwise copy
__host__
__device__
static
constexpr
auto
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
E1GridDesc_M_N
&
e1_grid_desc_m_n
)
{
const
auto
M
=
e1_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e1_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
Gemm0MPerBlock
;
const
auto
NBlock
=
N
/
Gemm1NPerBlock
;
const
auto
e1_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
e1_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
Gemm0MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
e1_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// D0s desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
D0sGridDesc_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumD0Tensor
>
{});
}
// Ds desc for source in blockwise copy
template
<
typename
DsGridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
DsGridDescriptor_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumD1Tensor
>
{});
}
// return block_id to C1 matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2E1TileMap
(
const
E1GridDesc_M_N
&
e1_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
Gemm0MPerBlock
,
Gemm1NPerBlock
,
E1GridDesc_M_N
>
(
e1_grid_desc_m_n
);
}
using
E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
E1GridDesc_M_N
{}))
>
;
using
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
D0sGridDesc_M_N
{}))
>
;
using
D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
D1sGridDesc_M_N
{}))
>
;
using
DefaultBlock2E1TileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2E1TileMap
(
E1GridDesc_M_N
{}))
>
;
struct
SharedMemTrait
{
// LDS allocation for A0 and B0: be careful of alignment
static
constexpr
auto
a0_block_desc_ak0_m_ak1
=
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
static
constexpr
auto
b0_block_desc_bk0_n_bk1
=
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
max_lds_align
=
math
::
lcm
(
math
::
lcm
(
A0K1
,
B0K1
),
B1K1
);
static
constexpr
auto
a0_block_space_size_aligned
=
math
::
integer_least_multiple
(
a0_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b0_block_space_size_aligned
=
math
::
integer_least_multiple
(
b0_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b1_block_space_size_aligned
=
math
::
integer_least_multiple
(
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
a0_block_space_offset
=
0
;
static
constexpr
auto
b0_block_space_offset
=
a0_block_space_size_aligned
.
value
;
static
constexpr
auto
b1_block_space_offset
=
0
;
// LDS allocation for C1 shuffle in LDS
static
constexpr
auto
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
static
constexpr
auto
c1_block_space_size
=
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
};
using
D0sGridPointer
=
decltype
(
MakeD0sGridPointer
());
using
D1sGridPointer
=
decltype
(
MakeD1sGridPointer
());
template
<
bool
HasMainKBlockLoop
,
typename
A0GridDesc_AK0_M_AK1
,
typename
B0GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
Block2E1TileMap
>
__device__
static
void
Run
(
const
A0B0B1DataType
*
__restrict__
p_a0_grid
,
const
A0B0B1DataType
*
__restrict__
p_b0_grid
,
D0sGridPointer
p_d0s_grid
,
const
A0B0B1DataType
*
__restrict__
p_b1_grid
,
D1sGridPointer
p_d1s_grid
,
E1DataType
*
__restrict__
p_e1_grid
,
void
*
__restrict__
p_shared
,
const
A0ElementwiseOperation
&
a0_element_op
,
const
B0ElementwiseOperation
&
b0_element_op
,
const
CDE0ElementwiseOperation
&
cde0_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
CDE1ElementwiseOperation
&
cde1_element_op
,
const
A0GridDesc_AK0_M_AK1
&
a0_grid_desc_ak0_m_ak1
,
const
B0GridDesc_BK0_N_BK1
&
b0_grid_desc_bk0_n_bk1
,
const
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
d1s_grid_desc_mblock_mperblock_nblock_nperblock
,
const
E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2E1TileMap
&
block_2_e1tile_map
)
{
const
auto
a0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a0_grid
,
a0_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b0_grid
,
b0_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
const
auto
b1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b1_grid
,
b1_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
e1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_e1_grid
,
e1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
d0s_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0s_grid
[
i
],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
].
GetElementSpaceSize
());
},
Number
<
NumD0Tensor
>
{});
const
auto
d1s_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d1s_grid
[
i
],
d1s_grid_desc_mblock_mperblock_nblock_nperblock
[
i
].
GetElementSpaceSize
());
},
Number
<
NumD1Tensor
>
{});
// divide block work by [M, N]
const
auto
block_work_idx
=
block_2_e1tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_e1tile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
e1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
e1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
Gemm0MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
Gemm1NPerBlock
);
// A0 matrix in LDS memory, dst of blockwise copy
constexpr
auto
a0_block_desc_ak0_m_ak1
=
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B0 matrix in LDS memory, dst of blockwise copy
constexpr
auto
b0_block_desc_bk0_n_bk1
=
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
//
// set up Gemm0
//
// A0 matrix blockwise copy
auto
a0_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
A0ElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
A0K0PerBlock
,
Gemm0MPerBlock
,
A0K1
>
,
A0BlockTransferThreadClusterLengths_AK0_M_AK1
,
A0BlockTransferThreadClusterArrangeOrder
,
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
a0_grid_desc_ak0_m_ak1
),
decltype
(
a0_block_desc_ak0_m_ak1
),
A0BlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
A0BlockTransferSrcVectorDim
,
2
,
A0BlockTransferSrcScalarPerVector
,
A0BlockTransferDstScalarPerVector_AK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemm0KPrefetchStage
>
(
a0_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a0_element_op
,
a0_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// B0 matrix blockwise copy
auto
b0_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
B0ElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
B0K0PerBlock
,
Gemm0NPerBlock
,
B0K1
>
,
B0BlockTransferThreadClusterLengths_BK0_N_BK1
,
B0BlockTransferThreadClusterArrangeOrder
,
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
b0_grid_desc_bk0_n_bk1
),
decltype
(
b0_block_desc_bk0_n_bk1
),
B0BlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
B0BlockTransferSrcVectorDim
,
2
,
B0BlockTransferSrcScalarPerVector
,
B0BlockTransferDstScalarPerVector_BK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemm0KPrefetchStage
>
(
b0_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
// will loop over GemmN dimension
b0_element_op
,
b0_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// Fused Gemm+Gemm pipeline
// for n in N0:
// for k in K0:
// acc[m][n] += A[m][k] * B0[k][n]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
A0K1
,
B0K1
),
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm0
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
A0B0B1DataType
,
Acc0DataType
,
decltype
(
a0_block_desc_ak0_m_ak1
),
decltype
(
b0_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
a0_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b0_block_desc_bk0_n_bk1
)),
Gemm0MPerBlock
,
Gemm0NPerBlock
,
Gemm0KPerBlock
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
Gemm0MXdlPerWave
,
Gemm0NXdlPerWave
,
KPack
,
true
>
{};
// TransposeC
auto
acc0_thread_buf
=
blockwise_gemm0
.
GetCThreadBuffer
();
// LDS allocation for A0 and B0: be careful of alignment
auto
a0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
A0B0B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a0_block_space_offset
,
a0_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
A0B0B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b0_block_space_offset
,
b0_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a0_block_slice_copy_step
=
make_multi_index
(
Gemm0KPerBlock
/
A0K1
,
0
,
0
);
constexpr
auto
b0_block_slice_copy_step
=
make_multi_index
(
Gemm0KPerBlock
/
B0K1
,
0
,
0
);
const
auto
a0_block_reset_copy_step
=
make_multi_index
(
-
a0_grid_desc_ak0_m_ak1
.
GetLength
(
I0
),
0
,
0
);
const
auto
b0_block_reset_copy_step
=
make_multi_index
(
-
b0_grid_desc_bk0_n_bk1
.
GetLength
(
I0
),
Gemm0NPerBlock
,
0
);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const
auto
gridwise_gemm0_pipeline
=
GridwiseGemmPipeline_v1_Selector
<
NumGemm0KPrefetchStage
,
LoopScheduler
::
Default
>
();
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a0_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a0_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
Gemm0KPerBlock
);
//
// set up Gemm1
//
// Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr
auto
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
blockwise_gemm0
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
n0
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
m1
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
n1
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
m2
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
n2
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
n3
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
n4
=
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
b1_block_slice_copy_step
=
make_multi_index
(
Gemm1KPerBlock
/
B1K1
,
0
,
0
);
// d0 matrix threadwise copy
constexpr
auto
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// MRepeat
I1
,
// NRepeat
I1
,
// MWaveId
I1
,
// NWaveId
I1
,
// MPerXdl
I1
,
// NGroupNum
I1
,
// NInputNum
n4
));
// registerNum
auto
d0s_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
A0B0B1DataType
,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
{};
},
Number
<
NumD0Tensor
>
{});
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
constexpr
auto
acc0_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
Gemm0MXdlPerWave
>
{},
Number
<
Gemm0NXdlPerWave
>
{},
n2
,
n4
));
auto
d0s_threadwise_copy
=
generate_tuple
(
[
&
](
auto
i
)
{
return
ThreadwiseTensorSliceTransfer_v2
<
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
]),
decltype
(
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
Sequence
<
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
n4
,
1
,
false
>
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
));
// register number
},
Number
<
NumD0Tensor
>
{});
// acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
constexpr
auto
acc0_thread_desc_k0_m_k1
=
transform_tensor_descriptor
(
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
n0
,
n1
,
n2
,
n3
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
m0
,
m1
,
m2
)),
make_pass_through_transform
(
n4
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
6
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr
auto
Acc0N3
=
blockwise_gemm0
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLength
(
I6
);
constexpr
auto
A1ThreadSlice_K0_M_K1
=
make_tuple
(
Number
<
Gemm1KPerBlock
/
n4
/
Acc0N3
>
{},
Number
<
m0
*
m1
*
m2
>
{},
Number
<
n4
>
{});
constexpr
auto
A1ThreadSliceK0
=
A1ThreadSlice_K0_M_K1
[
I0
];
constexpr
auto
A1ThreadSliceM
=
A1ThreadSlice_K0_M_K1
[
I1
];
constexpr
auto
A1ThreadSliceK1
=
A1ThreadSlice_K0_M_K1
[
I2
];
constexpr
auto
a1_thread_desc_k0_m_k1
=
make_naive_tensor_descriptor
(
A1ThreadSlice_K0_M_K1
,
make_tuple
(
A1ThreadSliceM
*
A1ThreadSliceK1
,
A1ThreadSliceK1
,
I1
));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A1 matrix blockwise copy
auto
a1_blockwise_copy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
Acc0DataType
,
A0B0B1DataType
,
decltype
(
acc0_thread_desc_k0_m_k1
),
decltype
(
a1_thread_desc_k0_m_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
A1ThreadSliceK0
,
A1ThreadSliceM
,
A1ThreadSliceK1
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
n4
>
{
tensor_operation
::
element_wise
::
PassThrough
{}};
// B1 matrix blockwise copy
auto
b1_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
B0ElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
B1K0PerBlock
,
Gemm1NPerBlock
,
B1K1
>
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
b1_grid_desc_bk0_n_bk1
),
decltype
(
b1_block_desc_bk0_n_bk1
),
B1BlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
B1BlockTransferSrcVectorDim
,
2
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
1
,
1
,
B1ThreadTransferSrcResetCoordinateAfterRun
,
true
,
// DstResetCoord
1
>
(
b1_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b1_element_op
,
b1_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
a1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
A0B0B1DataType
>
(
a1_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
// reuse LDS space for gemm0's b0_block_buf
auto
b1_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
A0B0B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
index_t
Gemm1KPack
=
math
::
max
(
math
::
lcm
(
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm1
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
A0B0B1DataType
,
Acc1DataType
,
decltype
(
a1_thread_desc_k0_m_k1
),
decltype
(
b1_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
a1_thread_desc_k0_m_k1
)),
decltype
(
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
b1_block_desc_bk0_n_bk1
)),
Gemm0MPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
Gemm0MXdlPerWave
,
Gemm1NXdlPerWave
,
Gemm1KPack
,
false
,
// TransposeC
Gemm1KPack
,
// AMmaKStride
Gemm1KPack
*
XdlopsGemm
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
Gemm1KPack
,
false
>
{}
.
K0PerXdlops
>
{
// BMmaKStride
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
auto
c1_thread_buf
=
blockwise_gemm1
.
GetCThreadBuffer
();
const
index_t
num_gemm1_k_block_outer_loop
=
b0_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)
/
Gemm0NPerBlock
;
constexpr
index_t
num_gemm1_k_block_inner_loop
=
Gemm0NPerBlock
/
Gemm1KPerBlock
;
// Initialize C1
c1_thread_buf
.
Clear
();
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
{
// gemm0
gridwise_gemm0_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a0_grid_desc_ak0_m_ak1
,
a0_block_desc_ak0_m_ak1
,
a0_blockwise_copy
,
a0_grid_buf
,
a0_block_buf
,
a0_block_slice_copy_step
,
b0_grid_desc_bk0_n_bk1
,
b0_block_desc_bk0_n_bk1
,
b0_blockwise_copy
,
b0_grid_buf
,
b0_block_buf
,
b0_block_slice_copy_step
,
blockwise_gemm0
,
acc0_thread_buf
,
num_k_block_main_loop
);
// bias+gelu
{
static_for
<
0
,
Gemm0MXdlPerWave
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
Gemm0NXdlPerWave
,
1
>
{}([
&
](
auto
nr
)
{
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
groupid
)
{
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
Run
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
d0s_grid_buf
[
i
],
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d0s_thread_buf
(
i
));
});
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
acc0_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
nr
,
groupid
,
i
));
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
return
d0s_thread_buf
[
iSrc
][
i
];
},
Number
<
NumD0Tensor
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
)
->
auto
&
{
return
acc0_thread_buf
(
Number
<
c_offset
>
{});
},
Number
<
2
>
{});
unpack2
(
cde0_element_op
,
dst_data_refs
,
src_data_refs
);
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
-
n2
.
value
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
1
,
-
Gemm0NXdlPerWave
,
0
,
0
,
0
,
0
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
1
,
-
Gemm0MXdlPerWave
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
});
}
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
b1_blockwise_copy
.
RunRead
(
b1_grid_desc_bk0_n_bk1
,
b1_grid_buf
);
b1_blockwise_copy
.
MoveSrcSliceWindow
(
b1_grid_desc_bk0_n_bk1
,
b1_block_slice_copy_step
);
block_sync_lds
();
// wait for gemm0 LDS read
b1_blockwise_copy
.
RunWrite
(
b1_block_desc_bk0_n_bk1
,
b1_block_buf
);
// main body
if
constexpr
(
num_gemm1_k_block_inner_loop
>
1
)
{
static_for
<
0
,
num_gemm1_k_block_inner_loop
-
1
,
1
>
{}([
&
](
auto
i
)
{
a1_blockwise_copy
.
Run
(
acc0_thread_desc_k0_m_k1
,
make_tuple
(
Number
<
i
*
A1ThreadSliceK0
>
{},
I0
,
I0
),
acc0_thread_buf
,
a1_thread_desc_k0_m_k1
,
make_tuple
(
I0
,
I0
,
I0
),
a1_thread_buf
);
b1_blockwise_copy
.
RunRead
(
b1_grid_desc_bk0_n_bk1
,
b1_grid_buf
);
block_sync_lds
();
blockwise_gemm1
.
Run
(
a1_thread_buf
,
b1_block_buf
,
c1_thread_buf
);
block_sync_lds
();
b1_blockwise_copy
.
MoveSrcSliceWindow
(
b1_grid_desc_bk0_n_bk1
,
b1_block_slice_copy_step
);
b1_blockwise_copy
.
RunWrite
(
b1_block_desc_bk0_n_bk1
,
b1_block_buf
);
});
}
// tail
{
a1_blockwise_copy
.
Run
(
acc0_thread_desc_k0_m_k1
,
make_tuple
(
Number
<
(
num_gemm1_k_block_inner_loop
-
1
)
*
A1ThreadSliceK0
>
{},
I0
,
I0
),
acc0_thread_buf
,
a1_thread_desc_k0_m_k1
,
make_tuple
(
I0
,
I0
,
I0
),
a1_thread_buf
);
block_sync_lds
();
blockwise_gemm1
.
Run
(
a1_thread_buf
,
b1_block_buf
,
c1_thread_buf
);
}
}
// end gemm1
a0_blockwise_copy
.
MoveSrcSliceWindow
(
a0_grid_desc_ak0_m_ak1
,
a0_block_reset_copy_step
);
// rewind K
b0_blockwise_copy
.
MoveSrcSliceWindow
(
b0_grid_desc_bk0_n_bk1
,
b0_block_reset_copy_step
);
// rewind K and step N
block_sync_lds
();
// wait for gemm1 LDS read
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
// shuffle C1 and write out
{
static_assert
(
Gemm0MXdlPerWave
%
C1ShuffleGemm0MXdlPerWavePerShuffle
==
0
&&
Gemm1NXdlPerWave
%
C1ShuffleGemm0NXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
Gemm0MPerBlock
/
(
Gemm0MXdlPerWave
*
Gemm0MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
Gemm0NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm1
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm1
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c1_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
C1ShuffleDataType
*>
(
p_shared
),
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
C1ShuffleGemm0MXdlPerWavePerShuffle
>
{},
// M0 (Gemm0MXdlPerWave) per
// shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = Gemm0MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
C1ShuffleGemm0NXdlPerWavePerShuffle
>
{},
// N0 (Gemm0NXdlPerWave) per
// shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = Gemm0NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM C1 matrix starting index
const
auto
c1_thread_mtx_on_block
=
blockwise_gemm1
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c1_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c1_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c1_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
Acc1DataType
,
C1ShuffleDataType
,
decltype
(
c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
tensor_operation
::
                element_wise::PassThrough,
                Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle,
                         I1, I1, M2, I1, M4, I1>,
                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                7,
                1,
                InMemoryDataOperationEnum::Set,
                1,
                true>{c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(0,
                                       0,
                                       m_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       m_thread_data_on_block_idx[I3],
                                       m_thread_data_on_block_idx[I4],
                                       n_thread_data_on_block_idx[I2]),
                      tensor_operation::element_wise::PassThrough{}};

        // tuple of reference to C/Ds tensor descriptors
        const auto c1_d1s_desc_refs = concat_tuple_of_reference(
            tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                Number<NumD1Tensor>{}));

        // tuple of reference to C/Ds tensor descriptors
        const auto c1_d1s_buf_refs = concat_tuple_of_reference(
            tie(c1_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return d1s_grid_buf[i]; },
                Number<NumD1Tensor>{}));

        // tuple of starting index of C/Ds blockwise copy
        const auto idx_c1_d1s_block_begin = container_concat(
            make_tuple(make_multi_index(0, 0, 0, 0)),
            generate_tuple(
                [&](auto) {
                    return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                },
                Number<NumD1Tensor>{}));

        // shuffle: blockwise copy C from LDS to global
        auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
            ThisThreadBlock,
            decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})),
            Tuple<E1DataType>,
            decltype(c1_d1s_desc_refs),
            decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
            CDE1ElementwiseOperation,
            Sequence<static_cast<index_t>(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence support arbitray type
            Sequence<1,
                     C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
                     1,
                     C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>, // BlockSliceLengths,
            CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
            3,                    // index_t VectorDim,
            CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
            sequence_merge_t<Sequence<true>,
                             uniform_sequence_gen_t<NumD1Tensor, false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
            Sequence<false>>                                              // ThreadTransferDstResetCoordinateAfterRunFlags
            {c1_d1s_desc_refs,
             idx_c1_d1s_block_begin,
             tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
             cde1_element_op};

        // space filling curve for threadwise C in VGPR
        constexpr auto sfc_c1_vgpr = SpaceFillingCurve<
            Sequence<Gemm0MXdlPerWave, Gemm1NXdlPerWave, 1, 1, M2, 1, M4, 1>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
            Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle,
                     1, 1, M2, 1, M4, 1>>{};

        // space filling curve for shuffled blockwise C in global mem
        constexpr auto sfc_e1_global = SpaceFillingCurve<
            Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>,
            Sequence<0, 2, 1, 3>,
            Sequence<1,
                     C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
                     1,
                     C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{};

        constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");

        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread write its data from VGPR to LDS
            c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                           sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
                                           c1_thread_buf,
                                           c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                           c1_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // each block copy its data from LDS to global
            cde1_shuffle_block_copy_lds_to_global.Run(
                c1_d1s_desc_refs,
                c1_d1s_buf_refs,
                tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e1_grid_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);

                // move on D1s
                static_for<0, NumD1Tensor, 1>{}([&](auto i) {
                    cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
                        c1_d1s_desc_refs, i + I1, e1_global_step);
                });

                // move on C
                cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                    tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
            }
        });
        }
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp View file @ b89a88b5
...
@@ -75,7 +75,8 @@ template <typename FloatAB,
              index_t CShuffleNXdlPerWavePerShuffle,
              typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
              index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
              LoopScheduler LoopSched>
              LoopScheduler LoopSched,
              bool PadN>
    struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    {
        static_assert(LoopSched == LoopScheduler::Default,
...
@@ -182,11 +183,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        return math::max((SharedMemTrait::a_block_space_size_aligned +
                          SharedMemTrait::b_block_space_size_aligned) * sizeof(FloatAB) +
                             SharedMemTrait::reduction_workspace * sizeof(FloatGemmAcc),
                         SharedMemTrait::c_block_size * sizeof(FloatCShuffle));
        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                         SharedMemTrait::b_block_space_size_aligned) * sizeof(FloatAB);
        const index_t gemm1_bytes_end =
            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
            sizeof(FloatAB);
        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                           SharedMemTrait::reduction_space_size_aligned) *
                                          sizeof(FloatGemmAcc);
        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
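The reworked LDS budget above takes the maximum over the byte offsets at which each phase's allocation ends, since gemm0, gemm1, the softmax reduction workspace, and the C-shuffle tile occupy the same dynamic LDS region at different times. A minimal standalone sketch of that sizing pattern; the names and byte counts below are illustrative, not the kernel's actual traits:

    #include <algorithm>
    #include <cstddef>

    // Hypothetical per-phase layouts; every phase reuses the same LDS block,
    // so the allocation only needs to cover the largest "end" offset.
    constexpr std::size_t a_bytes = 4096, b_bytes = 4096;   // gemm0 tiles
    constexpr std::size_t b1_off = 0, b1_bytes = 4096;      // gemm1 tile reuses the front of LDS
    constexpr std::size_t red_off = 8192, red_bytes = 1024; // softmax reduction workspace
    constexpr std::size_t c_bytes = 8192;                   // C-shuffle tile

    constexpr std::size_t lds_bytes = std::max({a_bytes + b_bytes, b1_off + b1_bytes,
                                                red_off + red_bytes, c_bytes});
    static_assert(lds_bytes == 9216, "max over the per-phase end offsets");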
@@ -237,8 +246,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            return false;
        }

        assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);

        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
        {
            return false;
...
@@ -302,25 +309,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // B1 can reuse B's LDS
        static constexpr auto b_block_space_size_aligned =
            math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);
        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;

        // LDS allocation for reduction
        static constexpr index_t reduction_workspace = BlockSize;
        static constexpr index_t reduction_space_size_aligned =
            math::integer_least_multiple(BlockSize, max_lds_align);
        static constexpr auto reduction_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        static constexpr auto c_block_size =
        static constexpr auto c_block_space_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };

    template <bool Pred>
    struct ElementOpPredicatedResetNaNToMinusInf;

    template <>
    struct ElementOpPredicatedResetNaNToMinusInf<true>
    {
        template <typename ElementOp, typename OutT, typename InT>
        __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
        {
            if(ck::math::isnan(x))
            {
                y = -ck::NumericLimits<float>::Infinity();
            }
            else
            {
                op(y, x);
            }
        }
    };

    template <>
    struct ElementOpPredicatedResetNaNToMinusInf<false>
    {
        template <typename ElementOp, typename OutT, typename InT>
        __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
        {
            op(y, x);
        }
    };

    template <bool HasMainKBlockLoop, typename Block2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
...
@@ -339,10 +379,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto a_grid_buf = conditional_expr<PadN>(
            make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize(), NumericLimits<FloatAB>::QuietNaN()),
            make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
        const auto b_grid_buf = conditional_expr<PadN>(
            make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize(), NumericLimits<FloatAB>::QuietNaN()),
            make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...
@@ -471,10 +521,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
            a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
...
@@ -549,11 +600,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            FloatAB,
            decltype(acc_thread_desc_k0_m_k1),
            decltype(a1_thread_desc_k0_m_k1),
            decltype(acc_element_op),
            tensor_operation::element_wise::PassThrough,
            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
            Sequence<1, 0, 2>,
            2,
            n4>{acc_element_op};
            n4>{tensor_operation::element_wise::PassThrough{}};

        // B1 matrix blockwise copy
        auto b1_blockwise_copy =
...
@@ -591,12 +642,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // reuse LDS space for gemm0's b_block_buf
        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr index_t Gemm1KPack = math::max(
            math::lcm(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
        //
        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        constexpr index_t Gemm1KPack = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;

        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
...
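The comment above replaces the lcm-based Gemm1KPack with plain group_size. A tiny worked example shows why the lcm can over-pack; the numbers here are made up for illustration and are not tied to any particular MFMA instruction:

    #include <algorithm>
    #include <numeric>

    // Assumed illustrative values: selected_mfma.group_size == 4 (so A1K1 == 4), B1K1 == 8.
    constexpr int group_size = 4, b1_k1 = 8, k_per_blk = 4;
    constexpr int old_pack = std::max(std::lcm(group_size, b1_k1), k_per_blk); // == 8, exceeds A1K1
    constexpr int new_pack = group_size;                                       // == 4, matches A1K1
    // With old_pack == 8 one xdlops K-pack would span a1[0:3] and a1[8:11], which are not
    // contiguous in the VGPR-distributed A1 tile; new_pack == group_size avoids that mismatch.
    static_assert(old_pack == 8 && new_pack == 4, "");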
@@ -617,7 +676,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            true,       // TransposeC
            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>
            {make_tuple(0, 0, 0, 0)};
            // TransposeC
            // BMmaKStride
            make_tuple(0, 0, 0, 0)}; // A_origin

        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
...
@@ -625,10 +685,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // Blockwise softmax
        //
        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatGemmAcc*>(p_shared) +
                SharedMemTrait::a_block_space_size_aligned * sizeof(FloatAB) / 4 +
                SharedMemTrait::b_block_space_size_aligned * sizeof(FloatAB) / 4,
            SharedMemTrait::reduction_workspace);
            static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
            SharedMemTrait::reduction_space_size_aligned);

        // get acc0 8D thread cluster
        constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
...
@@ -667,7 +725,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            FloatGemmAcc,
            decltype(threadid_to_m_n_thread_cluster_adaptor),
            decltype(thread_cluster_desc_m_n),
            decltype(thread_slice_desc_m_n)>{};
            decltype(thread_slice_desc_m_n)
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
            ,
            true
#endif
            >{};

        const index_t num_gemm1_k_block_outer_loop = b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...
@@ -706,6 +769,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                   blockwise_gemm,
                                                   acc_thread_buf,
                                                   num_k_block_main_loop);

            // Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
            static_for<0, acc_thread_buf.Size(), 1>{}(
                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
            static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
                ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
                    acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
            });
#endif

            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm

            // softmax
            SoftmaxBuf& max = blockwise_softmax.max_value_buf;
            SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
...
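When PadN is set, out-of-range elements are read back as quiet NaNs (see the conditional_expr buffers earlier in this diff) and the predicated element-op rewrites any NaN to minus infinity before the softmax. A short standalone sketch of why that neutralizes padded columns; this uses plain floats, not the kernel's types:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // A padded attention logit read as NaN is rewritten to -inf, so after softmax its
        // weight is exp(-inf - max) == 0 and it cannot pollute the row sum.
        float logit = std::nanf("");
        if(std::isnan(logit)) { logit = -INFINITY; }
        const float row_max = 2.0f; // assumed running maximum for the row
        std::printf("weight = %f\n", std::exp(logit - row_max)); // prints 0.000000
    }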
@@ -717,7 +794,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
                              mathext::exp(max - running_max_new) * sum;

            block_sync_lds();

            // gemm1
            {
                // TODO: explore using dynamic buffer for a1 thread buffer
...
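The update above is the standard streaming-softmax rescaling: whenever a new tile raises the running maximum, the previously accumulated sum is rescaled so the final result matches a single-pass softmax. A small numeric sketch with illustrative values:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // Two tiles of logits {0, 1} and {2, 3}, processed one after the other.
        float running_max = 1.0f;
        float running_sum = std::exp(0.0f - 1.0f) + std::exp(1.0f - 1.0f);
        const float tile_max = 3.0f;
        const float tile_sum = std::exp(2.0f - 3.0f) + std::exp(3.0f - 3.0f);

        // Same form as running_sum_new above: rescale both partial sums to the new max.
        const float running_max_new = std::fmax(running_max, tile_max);
        const float running_sum_new = std::exp(running_max - running_max_new) * running_sum +
                                      std::exp(tile_max - running_max_new) * tile_sum;

        // Equals exp(0-3) + exp(1-3) + exp(2-3) + exp(3-3) computed in one pass (~1.553).
        std::printf("%f\n", running_sum_new);
    }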
@@ -736,12 +812,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
                                                     b1_block_slice_copy_step);

                block_sync_lds(); // wait for reduction LDS read

                b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);

                // main body
                if constexpr(num_gemm1_k_block_inner_loop > 1)
                {
                    static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                        a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                              make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
...
@@ -749,6 +826,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                              a1_thread_desc_k0_m_k1,
                                              make_tuple(I0, I0, I0),
                                              a1_thread_buf);

                        b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);

                        block_sync_lds();
...
@@ -773,6 +851,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                          a1_thread_desc_k0_m_k1,
                                          make_tuple(I0, I0, I0),
                                          a1_thread_buf);

                    block_sync_lds();

                    gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
...
@@ -817,6 +896,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            running_max = running_max_new;
            running_sum = running_sum_new;

            block_sync_lds(); // wait for gemm1 LDS read
        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

        // shuffle C and write out
...
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp deleted 100644 → 0 View file @ 41d5fca7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseBinEltwise, typename ADataType, typename BDataType, typename CDataType,
          typename AGridDesc_M, typename BGridDesc_M, typename CGridDesc_M, typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                             const BDataType* __restrict__ p_b_global,
                                             CDataType* __restrict__ p_c_global,
                                             const AGridDesc_M a_grid_desc_m,
                                             const BGridDesc_M b_grid_desc_m,
                                             const CGridDesc_M c_grid_desc_m,
                                             const ElementwiseFunctor functor)
{
    GridwiseBinEltwise::Run(
        p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}

template <typename ADataType, typename BDataType, typename CDataType, typename ComputeDataType,
          typename AGridDesc_M, typename BGridDesc_M, typename CGridDesc_M, typename ElementwiseFunctor,
          index_t MPerThread, index_t AScalarPerVector, index_t BScalarPerVector, index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             AGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             BGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             BScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               CDataType,
                                               decltype(thread_desc_m),
                                               CGridDesc_M,
                                               PassThrough,
                                               Sequence<MPerThread>, // SliceLengths
                                               Sequence<0>,          // DimAccessOrder
                                               0,                    // DstVectorDim
                                               CScalarPerVector,     // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{c_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = c_grid_desc_m.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = M / (loop_step);
        do
        {
            // read and process MPerThread elements
            a_global_load.Run(a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
            b_global_load.Run(b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);

            static_for<0, MPerThread, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

            c_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
                               c_grid_desc_m,
                               c_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
            c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp 0 → 100644 View file @ b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseElementwise1dFunctor, typename InGrid1dDescTuple, typename OutGrid1dDescTuple,
          typename InDataTypePointerTuple, typename OutDataTypePointerTuple, typename ElementwiseOperation>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
                                      const OutGrid1dDescTuple out_grid_1d_desc_tuple,
                                      const InDataTypePointerTuple p_in_global_tuple,
                                      const OutDataTypePointerTuple p_out_global_tuple,
                                      const ElementwiseOperation elementwise_op)
{
    GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
                                      out_grid_1d_desc_tuple,
                                      p_in_global_tuple,
                                      p_out_global_tuple,
                                      elementwise_op);
}

template <typename InGrid1dDescTuple, typename OutGrid1dDescTuple, typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple, typename ElementwiseOperation, index_t MPerThread,
          typename InScalarPerVectorSeq, typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGrid1dDescTuple::Size() &&
                      NumOutput == OutGrid1dDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
                               const OutGrid1dDescTuple out_grid_1d_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const ElementwiseOperation elementwise_op)
    {
        const index_t thread_global_id = get_thread_global_1d_id();

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
            },
            Number<NumInput>{});

        auto out_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
            },
            Number<NumOutput>{});

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = in_grid_1d_desc_tuple[I0].GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        DataType,
                                                        decltype(in_grid_1d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m),
                                                        Sequence<MPerThread>,        // SliceLengths
                                                        Sequence<0>,                 // DimAccessOrder
                                                        0,                           // SrcVectorDim
                                                        InScalarPerVectorSeq::At(I), // ScalarPerVector
                                                        1,                           // SrcScalarStrideInVector
                                                        false>{in_grid_1d_desc_tuple[I], thread_global_offset};
            },
            Number<NumInput>{});

        auto out_global_store_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                                          DataType,
                                                          decltype(thread_buffer_desc_m),
                                                          decltype(out_grid_1d_desc_tuple[I]),
                                                          PassThroughOp,
                                                          Sequence<MPerThread>, // SliceLengths
                                                          Sequence<0>,          // DimAccessOrder
                                                          0,                    // SrcVectorDim
                                                          OutScalarPerVectorSeq::At(I),
                                                          InMemoryDataOperationEnum::Set,
                                                          1,
                                                          false>(
                    out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
            },
            Number<NumOutput>{});

        index_t num_iter = M / (loop_step);
        do
        {
            static_for<0, NumInput, 1>{}([&](auto I) {
                in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
                                            in_global_buf_tuple[I],
                                            thread_buffer_desc_m,
                                            make_tuple(I0),
                                            in_thread_buf_tuple(I));

                in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I], loop_step_index);
            });

            static_for<0, MPerThread, 1>{}([&](auto iM) {
                // get reference to in data
                const auto in_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                // get reference to dst data
                auto out_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
                    Number<NumOutput>{});

                unpack2(elementwise_op, out_data_refs, in_data_refs);
            });

            static_for<0, NumOutput, 1>{}([&](auto I) {
                out_global_store_tuple(I).Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              out_thread_buf_tuple[I],
                                              out_grid_1d_desc_tuple[I],
                                              out_global_buf_tuple(I));

                out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I], loop_step_index);
            });
        } while(--num_iter);
    }
};

} // namespace ck
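For orientation, the structure of the new kernel above (a tuple of inputs, a tuple of outputs, one functor applied per element, a grid-stride loop in MPerThread chunks) reduces on the host to something like the following plain-C++ analogue. None of these names are CK APIs; it is only a sketch of the access pattern that generalizes the deleted unary/binary variants:

    #include <cstdio>
    #include <vector>

    // Plain-CPU analogue of GridwiseElementwise_1D: two inputs, one output, one functor.
    template <typename F>
    void elementwise_1d(F f, const std::vector<float>& a, const std::vector<float>& b,
                        std::vector<float>& out, std::size_t m_per_thread = 4)
    {
        // "grid-stride" loop, processing m_per_thread elements per step
        for(std::size_t base = 0; base < out.size(); base += m_per_thread)
            for(std::size_t m = 0; m < m_per_thread && base + m < out.size(); ++m)
                f(out[base + m], a[base + m], b[base + m]); // output reference first, then inputs
    }

    int main()
    {
        std::vector<float> a{1, 2, 3, 4}, b{10, 20, 30, 40}, c(4);
        elementwise_1d([](float& y, float x0, float x1) { y = x0 + x1; }, a, b, c);
        std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 11 22 33 44
    }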
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp View file @ b89a88b5
...
@@ -32,8 +32,8 @@ template <typename FloatAB,
          typename ThreadReduceOperations,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          typename RsGlobalMemoryDataOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename AGridDesc_M_K,
          typename BGridDesc_N_K,
          typename EGridDesc_M_N,
          typename RGridDesc_M,
          index_t NumGemmKPrefetchStage,
...
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};
    static constexpr auto AK1         = Number<AK1Value>{};
    static constexpr auto BK1         = Number<BK1Value>{};
    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0, Number<MPerBlock>{}, AK1),
            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }
...
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0, Number<NPerBlock>{}, BK1),
            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }
...
@@ -167,22 +167,57 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                         c_block_size * sizeof(FloatCShuffle));
    }

    // A desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // B desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2ETileMap>
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                                            const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                                            const EGridDesc_M_N& e_grid_desc_m_n,
                                                            const RGridDesc_M& r_grid_desc_m,
                                                            const Block2ETileMap& block_2_etile_map)
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                                            const BGridDesc_N_K& b_grid_desc_n_k,
                                                            const EGridDesc_M_N& e_grid_desc_m_n,
                                                            const RGridDesc_M& r_grid_desc_m,
                                                            const Block2ETileMap& block_2_etile_map)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
        const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
        static_assert(AGridDesc_M_K::GetNumOfDimension() == 2);
        static_assert(BGridDesc_N_K::GetNumOfDimension() == 2);
        static_assert(EGridDesc_M_N::GetNumOfDimension() == 2);

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
            return false;
...
@@ -259,6 +294,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            e_grid_desc_m_n);
    }

    using DefaultAGridDesc_AK0_M_AK1 =
        remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using DefaultBGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
        remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
...
@@ -272,7 +311,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    using DsGridPointer = decltype(MakeTsGridPointer<DsDataType, true>());
    using RsGridPointer = decltype(MakeTsGridPointer<RsDataType, false>());

    template <bool HasMainKBlockLoop, typename Block2ETileMap>
    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename Block2ETileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
...
@@ -356,7 +398,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0, MPerBlock, AK1>,
            Sequence<AK0PerBlock, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
...
@@ -387,7 +429,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0, NPerBlock, BK1>,
            Sequence<BK0PerBlock, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
...
@@ -776,8 +818,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to read from LDS
            if constexpr(access_id > 0)
                block_sync_lds();
            block_sync_lds();

            // each thread shuffle data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
...
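The new MakeDefaultAGridDescriptor_AK0_M_AK1 above unmerges the K dimension of an (M, K) descriptor into (K0, K1) with K1 fixed; in index terms the mapping is simply k = k0 * AK1 + k1. A tiny standalone illustration with made-up sizes:

    #include <cstdio>

    int main()
    {
        constexpr int K = 32, AK1 = 8, AK0 = K / AK1; // K must be divisible by AK1
        // element (m, k) of A is addressed as (k0, m, k1) after the unmerge transform
        for(int k : {0, 7, 8, 31})
            std::printf("k=%2d -> (k0=%d, k1=%d)\n", k, k / AK1, k % AK1);
        (void)AK0;
    }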
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp View file @ b89a88b5
...
@@ -53,7 +53,7 @@ __global__ void
    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared_block,
                                                  static_cast<void*>(p_shared_block),
                                                  a_b_k0_m_k1_grid_desc,
                                                  b_b_k0_n_k1_grid_desc,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
...
@@ -270,7 +270,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               void* __restrict__ p_shared_block,
                               const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                               const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
...
@@ -463,8 +463,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;
        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
...
@@ -547,11 +547,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            static_cast<FloatC*>(p_shared_block),
            c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        static_assert(M1 == MWave, "");
        static_assert(N1 == NWave, "");
        static_assert(M2 * M3 * M4 == MPerXDL, "");
        static_assert(N2 == NPerXDL, "");

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_block_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(
...
include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp 0 → 100644 View file @ b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename Grid1dBufferDescTuple,
          index_t NumBuffer,
          index_t BlockSize,
          typename DataTypePointerTuple,
          typename DataTypeTuple>
__global__ void kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple,
                                                 DataTypePointerTuple p_global_tuple,
                                                 DataTypeTuple value_tuple)
{
    static_assert(NumBuffer == DataTypePointerTuple::Size() && NumBuffer == DataTypeTuple::Size(),
                  "The tuple size should be same as NumBuffer!");

    static_for<0, NumBuffer, 1>{}([&](auto iB) {
        using DataTypePointer     = remove_cvref_t<decltype(DataTypePointerTuple{}[iB])>;
        using DataTypeFromPointer = remove_pointer_t<DataTypePointer>;
        using DataType            = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;

        static_assert(is_same<DataType, DataTypeFromPointer>::value, "Types in tuples does not match!");
    });

    constexpr auto I0 = Number<0>{};

    const index_t thread_global_id = get_thread_global_1d_id();

    auto value_buf_tuple = generate_tuple(
        [&](auto iB) {
            using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;

            return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, 1, true>{};
        },
        Number<NumBuffer>{});

    static_for<0, NumBuffer, 1>{}([&](auto iB) {
        static_for<0, 1, 1>{}([&](auto J) { value_buf_tuple(iB)(J) = value_tuple[iB]; });
    });

    auto global_buf_tuple = generate_tuple(
        [&](auto iB) {
            return make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_global_tuple(iB), grid_1d_buffer_desc_tuple[iB].GetElementSpaceSize());
        },
        Number<NumBuffer>{});

    constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

    static_for<0, NumBuffer, 1>{}([&](auto iB) {
        using DataType = remove_cvref_t<decltype(DataTypeTuple{}[iB])>;

        using PassThroughOp = tensor_operation::element_wise::PassThrough;

        auto threadwise_store =
            ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                               DataType,
                                               decltype(val_buff_desc),
                                               decltype(Grid1dBufferDescTuple{}[iB]),
                                               PassThroughOp,
                                               Sequence<1>,
                                               Sequence<0>,
                                               0,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(grid_1d_buffer_desc_tuple[iB],
                                                     make_multi_index(thread_global_id),
                                                     PassThroughOp{});

        threadwise_store.Run(val_buff_desc,
                             make_tuple(I0),
                             value_buf_tuple(iB),
                             grid_1d_buffer_desc_tuple[iB],
                             global_buf_tuple(iB));
    });
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp 0 → 100644 View file @ b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"

namespace ck {

template <typename GridwiseSparseEmbedding,
          typename EmbType,
          typename IndexType,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename OutType,
          typename OutGridDesc>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    __global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
                                                               const EmbType* p_emb_a,
                                                               const EmbType* p_emb_b,
                                                               const EmbType* p_emb_c,
                                                               const IndexType* p_index_a,
                                                               const IndexType* p_index_b,
                                                               const IndexType* p_index_c,
                                                               const GammaDataType* p_gamma,
                                                               const BetaDataType* p_beta,
                                                               const OutGridDesc out_grid_desc,
                                                               const AccDataType epsilon)
{
    GridwiseSparseEmbedding::Run(p_out,
                                 p_emb_a,
                                 p_emb_b,
                                 p_emb_c,
                                 p_index_a,
                                 p_index_b,
                                 p_index_c,
                                 p_gamma,
                                 p_beta,
                                 out_grid_desc,
                                 epsilon);
}

template <typename EmbType,
          typename IndexType,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename OutType,
          typename OutGridDesc,
          ck::index_t BlockSize,
          ck::index_t DimClusterSize,
          ck::index_t RowClusterSize,
          ck::index_t DimPerBlock,   // Row x Dim, along Dim
          ck::index_t RowPerBlock,   // Row x Dim, along Row
          ck::index_t DimThreadSize, // this is actually not vector, but number of registers
          ck::index_t RowVectorSize>
struct GridwiseSparseEmbedding3ForwardLayernorm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t WaveSize = 64;

    static_assert(BlockSize == RowClusterSize * DimClusterSize,
                  "Invalid cluster distribution within block");
    static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise");

    static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
    static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");

    static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
    static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);

    static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
    static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
    static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;

    using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));

    using ThreadwiseWolfordDescReduce = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;

    using ThreadClusterLength = Sequence<DimClusterSize, RowClusterSize>;
    using BlockwiseWelford =
        BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;

    __device__ static void Run(OutType* p_out,
                               const EmbType* p_emb_a,
                               const EmbType* p_emb_b,
                               const EmbType* p_emb_c,
                               const IndexType* p_index_a,
                               const IndexType* p_index_b,
                               const IndexType* p_index_c,
                               const GammaDataType* p_gamma,
                               const BetaDataType* p_beta,
                               const OutGridDesc,
                               const AccDataType epsilon)
    {
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        // const auto index_length = out_grid_desc.GetLength(I0);
        // const auto emb_dim = out_grid_desc.GetLength(I1);

        constexpr auto thread_cluster_desc =
            make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_dim_cluster_id = thread_cluster_idx[I0];
        const auto thread_row_cluster_id = thread_cluster_idx[I1];

        const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
        const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;

        constexpr auto thread_buf_size = DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
        constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
            make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
        constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
        constexpr auto mean_var_buf_desc =
            make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize));
        constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
        constexpr auto gamma_beta_buf_desc =
            make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));

        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;

        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true> gamma_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true> beta_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;

        auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
            using src_vector_t = typename decltype(emb_vector_a)::type;

            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
                constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
                IndexType index_a          = index_buf_a[Number<current_dim>{}];
                IndexType index_b          = index_buf_b[Number<current_dim>{}];
                IndexType index_c          = index_buf_c[Number<current_dim>{}];

                auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
                                     sizeof(EmbType) * RowVectorSize;

                int32x4_t emb_res_a =
                    make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
                int32x4_t emb_res_b =
                    make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
                int32x4_t emb_res_c =
                    make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);

                emb_vector_a.template AsType<src_vector_t>()(I0) =
                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
                emb_vector_b.template AsType<src_vector_t>()(I0) =
                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
                emb_vector_c.template AsType<src_vector_t>()(I0) =
                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);

                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
                    in_thread_buf_a(Number<register_offset>{}) =
                        emb_vector_a.template AsType<EmbType>()[i_row_vec_];
                    in_thread_buf_b(Number<register_offset>{}) =
                        emb_vector_b.template AsType<EmbType>()[i_row_vec_];
                    in_thread_buf_c(Number<register_offset>{}) =
                        emb_vector_c.template AsType<EmbType>()[i_row_vec_];
                });
            });
        };

        auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
                    AccDataType va =
                        ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
                    AccDataType vb =
                        ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
                    AccDataType vc =
                        ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));

                    acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
                });
            });
        };

        auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
                    constexpr auto mean_var_offset =
                        mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));

                    threadwise_welford.cur_count_++;
                    threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
                                              var_thread_buf(Number<mean_var_offset>{}),
                                              acc_thread_buf(Number<register_offset>{}));
                });
            });
        };

        auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
            int32x4_t out_res =
                make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock);
            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
                vector_type_maker_t<OutType, RowVectorSize> out_vector;
                using dst_vector_t = typename decltype(out_vector)::type;
                constexpr auto mean_var_offset =
                    mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));

                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
                    constexpr auto gamma_beta_offset =
                        gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));

                    auto acc_val = acc_thread_buf[Number<register_offset>{}];
                    acc_val      = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
                              sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
                    acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
                              beta_thread_buf[Number<gamma_beta_offset>{}];

                    out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
                        type_convert<OutType>(acc_val);
                });
                index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
                                        sizeof(OutType) * RowVectorSize;

                amd_buffer_store_impl<OutType, RowVectorSize>(
                    out_vector.template AsType<dst_vector_t>()[Number<0>{}], out_res, thread_offset, 0);
            });
        };

        // first load index
        ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
            // prefer use s_load
            index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
            index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
            index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
        });

        // load gamma/beta
        static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
            vector_type_maker_t<GammaDataType, RowVectorSize> gamma_vector;
            vector_type_maker_t<BetaDataType, RowVectorSize> beta_vector;
            index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
                                          sizeof(GammaDataType) * RowVectorSize;
            index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
                                         sizeof(BetaDataType) * RowVectorSize;

            int32x4_t gamma_res = make_wave_buffer_resource_with_default_range(p_gamma);
            int32x4_t beta_res  = make_wave_buffer_resource_with_default_range(p_beta);

            gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
                amd_buffer_load_impl<GammaDataType, RowVectorSize>(gamma_res, thread_offset_gamma, 0);
            beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
                amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);

            static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                constexpr auto offset =
                    gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
                gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
                    gamma_vector.template AsType<GammaDataType>()[Number<i_row_vec_>{}]);
                beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
                    beta_vector.template AsType<BetaDataType>()[Number<i_row_vec_>{}]);
            });
        });

        static_for<0, thread_buf_size, 1>{}(
            [&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
        static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
            load_current_sub_row(i_dim_sub, Number<0>{});
            static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
                load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
                accumulate_current_sub_row(i_dim_sub, i_row);
                threadwise_welford_sub_row(i_dim_sub, i_row);
            });
            accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
            threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});

            // blockwise welford
            static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();
                BlockwiseWelford::Run(
                    mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
            });

            // store
            static_for<0, RowSubBlocks, 1>{}(
                [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
        });
    }
};

} // namespace ck
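The per-thread and blockwise Welford reductions used in the layernorm above compute the mean and variance in a single pass over the summed embeddings. A standalone scalar sketch of the update rule (not the CK ThreadwiseWelford API, just the underlying recurrence):

    #include <cstdio>

    // One-pass Welford update: after processing x_1..x_n, `mean` is the running mean
    // and `m2 / n` is the (biased) variance that layernorm divides by.
    int main()
    {
        const float xs[] = {1.0f, 2.0f, 4.0f, 7.0f};
        float mean = 0.0f, m2 = 0.0f;
        int count = 0;
        for(float x : xs)
        {
            ++count;
            const float delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean);
        }
        std::printf("mean=%f var=%f\n", mean, m2 / count); // mean=3.5 var=5.25
    }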
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp deleted 100644 → 0 View file @ 41d5fca7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseUEltwise, typename ADataType, typename BDataType,
          typename GridDesc_M0, typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                            BDataType* __restrict__ p_b_global,
                                            const GridDesc_M0 a_grid_desc_m0,
                                            const GridDesc_M0 b_grid_desc_m0,
                                            const ElementwiseFunctor functor)
{
    GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}

template <typename ADataType, typename BDataType, typename GridDesc_M0,
          typename ElementwiseFunctor, index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * ScalarPerVector);
    }

    __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
                                                            const GridDesc_M0 b_grid_desc_m0)
    {
        return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
    {
        const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
        return grid_size;
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               BDataType* __restrict__ p_b_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ADataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_store_global_offset};

        auto b_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<BDataType,
                                               BDataType,
                                               decltype(thread_desc_m0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{b_grid_desc_m0, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto m0              = b_grid_desc_m0.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * ScalarPerVector;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = m0 / (loop_step);
        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
            });

            b_global_write.Run(thread_desc_m0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               b_thread_buf,
                               b_grid_desc_m0,
                               b_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
            b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck