gaoqiong / composable_kernel · Commits

Commit 644df335, authored Jan 30, 2023 by rocking

    Merge branch 'develop' into gemm_layernorm_instance

Parents: d99640ab, 7494c1c6
Changes: 254

Showing 20 changed files with 1897 additions and 218 deletions (+1897 / −218):
include/ck/tensor_operation/gpu/device/device_reduce.hpp  (+28 / −8)
include/ck/tensor_operation/gpu/device/device_softmax.hpp  (+4 / −6)
include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp  (+5 / −5)
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp  (+3 / −3)
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp  (+4 / −4)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp  (+669 / −0)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  (+2 / −3)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp  (+0 / −1)
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp  (+8 / −8)
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp  (+6 / −6)
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp  (+5 / −4)
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp  (+16 / −8)
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp  (+15 / −6)
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp  (+12 / −11)
include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp  (+44 / −61)
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  (+70 / −0)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  (+44 / −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp  (+157 / −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp  (+744 / −0)
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp  (+61 / −84)
include/ck/tensor_operation/gpu/device/device_reduce.hpp

@@ -13,10 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {

-template <index_t Rank,
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
           index_t NumReduceDim,
+          typename ReduceOperation,
           typename InElementwiseOperation,
-          typename AccElementwiseOperation>
+          typename AccElementwiseOperation,
+          bool PropagateNan,
+          bool OutputIndex>
struct DeviceReduce : public BaseOperator
{
    static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;

@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
                        const std::array<index_t, NumOutDim> outLengths,
                        const std::array<index_t, NumOutDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                       float alpha,
-                       float beta,
+                       double alpha,
+                       double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,

@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

-template <index_t Rank,
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
           index_t NumReduceDim,
+          typename ReduceOperation,
           typename InElementwiseOperation,
-          typename AccElementwiseOperation>
-using DeviceReducePtr = std::unique_ptr<
-    DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
+          typename AccElementwiseOperation,
+          bool PropagateNan,
+          bool OutputIndex>
+using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
+                                                     AccDataType,
+                                                     OutDataType,
+                                                     Rank,
+                                                     NumReduceDim,
+                                                     ReduceOperation,
+                                                     InElementwiseOperation,
+                                                     AccElementwiseOperation,
+                                                     PropagateNan,
+                                                     OutputIndex>>;

} // namespace device
} // namespace tensor_operation
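For callers, the visible effect of this hunk is that the abstract DeviceReduce interface and the DeviceReducePtr alias now carry the full problem description in their template arguments, and that alpha/beta are plain double values rather than floats. A minimal sketch of an instance alias after this change; the concrete data types and the choice of ck::reduce::Add are illustrative assumptions, not taken from the commit:

    // Hypothetical instance alias (types chosen for illustration only).
    using InDataType  = float;
    using AccDataType = float;
    using OutDataType = float;
    using ReduceOp    = ck::reduce::Add;                                  // assumed reduce operation
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;  // elementwise no-op

    // Rank-4 input reduced over 2 dimensions, no NaN propagation, no index output.
    using ReducePtr = ck::tensor_operation::device::
        DeviceReducePtr<InDataType, AccDataType, OutDataType, 4, 2, ReduceOp,
                        PassThrough, PassThrough, false, false>;

    // The scaling factors are now passed to MakeArgumentPointer(...) directly as double,
    // e.g. alpha = 1.0, beta = 0.0, instead of float.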
include/ck/tensor_operation/gpu/device/device_softmax.hpp

@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
    // @param[in] inLengths  Input tensor extent(s) from high to low dimension
    // @param[in] inStrides  Input tensor stride(s) from high to low dimension
    // @param[in] reduceDims The dimension(s) the normalization operation is applied
-   // @param[in] alpha      Typeless pointer in host memory storing the alpha scaling
-   //                       value as type AccDataType
-   // @param[in] beta       Typeless pointer in host memory storing the beta scaling
-   //                       value as type AccDataType
+   // @param[in] alpha      double type value
+   // @param[in] beta       double type value
    // @param[in] in_dev     Typeless const pointer in device memory storing the input tensor
    // @param     out_dev    Typeless pointer in device memory storing the output tensor

@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
    MakeArgumentPointer(const std::vector<index_t> inLengths,
                        const std::vector<index_t> inStrides,
                        const std::vector<int> reduceDims,
-                       const void* alpha,
-                       const void* beta,
+                       double alpha,
+                       double beta,
                        const void* in_dev,
                        void* out_dev,
                        InElementwiseOp in_elementwise_op,
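The softmax change follows the same pattern as the reductions: alpha and beta move from type-erased host pointers to double parameters that the implementation converts to AccDataType internally. A hedged before/after sketch of a caller; the surrounding variable names are assumed:

    // Before: scaling factors were passed through typeless host pointers.
    //   float alpha = 2.0f, beta = 0.0f;
    //   softmax_ptr->MakeArgumentPointer(inLengths, inStrides, reduceDims,
    //                                    &alpha, &beta, in_dev, out_dev, in_op, acc_op);

    // After this commit: pass the values directly as double.
    auto arg = softmax_ptr->MakeArgumentPointer(inLengths, inStrides, reduceDims,
                                                2.0, 0.0, in_dev, out_dev, in_op, acc_op);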
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp → include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp

@@ -8,7 +8,7 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/sequence.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -26,10 +26,10 @@ template <typename InDataTypeTuple,
           index_t NPerThread,
           typename InScalarPerVectorSeq,
           typename OutScalarPerVectorSeq>
-struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
-                                                        OutDataTypeTuple,
-                                                        ElementwiseOperation,
-                                                        NumDim_m + NumDim_n>
+struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
+                                                          OutDataTypeTuple,
+                                                          ElementwiseOperation,
+                                                          NumDim_m + NumDim_n>
{
    static constexpr index_t NumDim = NumDim_m + NumDim_n;
include/ck/tensor_operation/gpu/device/impl/device_elementwise.hpp → include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp

@@ -8,7 +8,7 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/sequence.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -25,8 +25,8 @@ template <typename InDataTypeTuple,
           index_t MPerThread,
           typename InScalarPerVectorSeq,
           typename OutScalarPerVectorSeq>
-struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
-                                                        OutDataTypeTuple,
-                                                        ElementwiseOperation,
-                                                        NumDim>
+struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
+                                                        OutDataTypeTuple,
+                                                        ElementwiseOperation,
+                                                        NumDim>
{
    static constexpr int NumInput  = InDataTypeTuple::Size();
    static constexpr int NumOutput = OutDataTypeTuple::Size();
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp

@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
                 const std::vector<index_t> reduceDims,
                 XElementwiseOperation x_elementwise_op,
                 YElementwiseOperation y_elementwise_op,
-                AccDataType epsilon,
+                double epsilon,
                 const std::array<const void*, NumInput> in_dev_buffers,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
-            : epsilon_(epsilon),
-              p_gamma_(p_gamma),
+            : p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              x_elementwise_op_(x_elementwise_op),
              y_elementwise_op_(y_elementwise_op)
        {
+            epsilon_ = static_cast<AccDataType>(epsilon);
+
            Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);

            for(int i = 0; i < NumInput; i++)

@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
-                       AccDataType epsilon,
+                       double epsilon,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const void* p_gamma,
                        const void* p_beta,
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename ABDataType,
          typename DsPointer,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_K0_M0_M1_K1,
          typename BGridDesc_K0_N0_N1_K1,
          typename DsGridDesc_M0_M10_M11_N0_N10_N11,
          typename CGridDesc_M0_M10_M11_N0_N10_N11,
          typename Block2CTileMap,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dl_multiple_d(const ABDataType* __restrict__ p_a_grid,
                                  const ABDataType* __restrict__ p_b_grid,
                                  DsPointer p_ds_grid,
                                  EDataType* __restrict__ p_e_grid,
                                  const AElementwiseOperation a_element_op,
                                  const BElementwiseOperation b_element_op,
                                  const CDEElementwiseOperation cde_element_op,
                                  const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
                                  const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
                                  const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
                                  const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
                                  const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx1030__))
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);

    __shared__ ABDataType p_shared[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_ds_grid,
                      p_e_grid,
                      p_shared,
                      a_element_op,
                      b_element_op,
                      cde_element_op,
                      a_grid_desc_k0_m0_m1_k1,
                      b_grid_desc_k0_n0_n1_k1,
                      ds_grid_desc_m0_m10_m11_n0_n10_n11,
                      e_grid_desc_m0_m10_m11_n0_n10_n11,
                      block_2_ctile_map,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = a_grid_desc_k0_m0_m1_k1;
    ignore = b_grid_desc_k0_n0_n1_k1;
    ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
    ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
    ignore = block_2_ctile_map;
#endif
}

} // namespace ck

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          typename M1N1ThreadClusterM1Xs,
          typename M1N1ThreadClusterN1Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          enable_if_t<
              is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
                  is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
              bool> = false>
struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
                                                           BLayout,
                                                           DsLayout,
                                                           ELayout,
                                                           ADataType,
                                                           BDataType,
                                                           DsDataType,
                                                           EDataType,
                                                           AElementwiseOperation,
                                                           BElementwiseOperation,
                                                           CDEElementwiseOperation>
{
    using DeviceOp = DeviceGemmMultipleD_Dl;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto K1Number = Number<K1>{};

    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(M, PadM)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    template <typename ELay>
    static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                                         const std::array<index_t, NumDTensor>& NRaws,
                                         const std::array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
            },
            Number<NumDTensor>{});
    }

    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using DsGridDesc_M_N    = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
    using EGridDesc_M_N     = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmDlMultipleD_km_kn_mn<
        BlockSize,
        ADataType,
        AccDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        EGridDesc_M_N,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        K1,
        M1PerThread,
        N1PerThread,
        KPerThread,
        M1N1ThreadClusterM1Xs,
        M1N1ThreadClusterN1Xs,
        ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
        ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
        ABlockTransferSrcVectorTensorContiguousDimOrder,
        ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
        BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
        BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
        BBlockTransferSrcVectorTensorContiguousDimOrder,
        BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
        CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector>;

    using AGridDesc_K0_M0_M1_K1 =
        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 =
        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using DsGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
    using EGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
    using DefaultBlock2CTileMap =
        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}));

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 std::array<index_t, NumDTensor> StrideDs,
                 index_t StrideE,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_k0_m0_m1_k1_{},
              b_grid_desc_k0_n0_n1_k1_{},
              e_grid_desc_m0_m10_m11_n0_n10_n11_{},
              block_2_ctile_map_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                // D desc
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
            });

            e_grid_desc_m_n_ = DeviceGemmMultipleD_Dl::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);

            if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_))
            {
                a_grid_desc_k0_m0_m1_k1_ =
                    GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
                b_grid_desc_k0_n0_n1_k1_ =
                    GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
                ds_grid_desc_m0_m10_m11_n0_n10_n11_ =
                    GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_);
                e_grid_desc_m0_m10_m11_n0_n10_n11_ =
                    GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_);

                block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_);
            }
        }

        //  private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
        BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
        DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
        EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;

        DefaultBlock2CTileMap block_2_ctile_map_;

        // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGemmMultipleD_Dl::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            {
                std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }

            if(!GridwiseGemm::CheckValidity(
                   arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_))
            {
                throw std::runtime_error("wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(
                arg.e_grid_desc_m_n_.GetLength(I0), arg.e_grid_desc_m_n_.GetLength(I1));

            auto launch_kernel = [&](auto has_main_k_block_loop, auto has_double_tail_k_block_loop) {
                constexpr bool has_main_loop   = has_main_k_block_loop.value;
                constexpr bool has_double_loop = has_double_tail_k_block_loop.value;

                const auto kernel =
                    kernel_gemm_dl_multiple_d<GridwiseGemm,
                                              ADataType,
                                              typename GridwiseGemm::DsGridPointer,
                                              EDataType,
                                              AElementwiseOperation,
                                              BElementwiseOperation,
                                              CDEElementwiseOperation,
                                              DeviceOp::AGridDesc_K0_M0_M1_K1,
                                              DeviceOp::BGridDesc_K0_N0_N1_K1,
                                              DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
                                              DeviceOp::EGridDesc_M0_M10_M11_N0_N10_N11,
                                              DefaultBlock2CTileMap,
                                              has_main_loop,
                                              has_double_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_k0_m0_m1_k1_,
                                              arg.b_grid_desc_k0_n0_n1_k1_,
                                              arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_,
                                              arg.e_grid_desc_m0_m10_m11_n0_n10_n11_,
                                              arg.block_2_ctile_map_);
            };

            const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);

            const bool has_main_k_block_loop        = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                return launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                return launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                return launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
           ck::get_device_name() == "gfx1030")
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
        }
        else
        {
            return false;
        }
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,     p_b,     p_ds,     p_e,      M, N, K, StrideA, StrideB,
                        StrideDs, StrideE, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t StrideA,
                        index_t StrideB,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a,      p_b,     p_ds,     p_e,      M, N, K,
                                          StrideA,  StrideB, StrideDs, StrideE,
                                          a_element_op, b_element_op, cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmMultipleD_Dl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << M1PerThread << ", "
            << N1PerThread << ", "
            << KPerThread
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
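The new DeviceGemmMultipleD_Dl follows the usual CK device-op pattern: build an Argument that describes E = cde_element_op(A·B, D0, ...), check that the target GPU and problem shape are supported, then run it through an Invoker. A minimal host-side sketch, assuming a hypothetical fully specialized alias DeviceOpInstance and already-allocated device buffers (both are assumptions, not part of the commit):

    // Sketch only: DeviceOpInstance stands for a full specialization of
    // ck::tensor_operation::device::DeviceGemmMultipleD_Dl<...> with one D tensor.
    auto device_op = DeviceOpInstance{};

    auto argument = device_op.MakeArgument(p_a_dev, p_b_dev,
                                           std::array<const void*, 1>{p_d0_dev}, p_e_dev,
                                           M, N, K, StrideA, StrideB,
                                           std::array<ck::index_t, 1>{StrideD0}, StrideE,
                                           PassThrough{}, PassThrough{}, cde_element_op);

    if(!DeviceOpInstance::IsSupportedArgument(argument))
    {
        throw std::runtime_error("problem not supported by DeviceGemmMultipleD_Dl");
    }

    auto invoker  = device_op.MakeInvoker();
    float time_ms = invoker.Run(argument, StreamConfig{nullptr, true}); // time the kernel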
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp

@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

-            const auto K =
-                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
-
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                                              arg.block_2_etile_map_);
            };

+            const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
+
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp

@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp

@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
        static_for<0, NumReduction, 1>{}([&](auto I) {
            using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

-            flag = flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                               OutDataType>::value;
+            flag = flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                                OutDataType>::value;
        });

        return flag;

@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
-                const std::array<const void*, NumReduction>& alphas,
-                const std::array<const void*, NumReduction>& betas,
+                const std::array<double, NumReduction>& alphas,
+                const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,

@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
            for(size_t i = 0; i < NumReduction; i++)
            {
-                alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
-                beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
+                alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
+                beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
                        const std::array<index_t, NumOutputDim> outLengths,
                        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
                        const std::array<int, NumReduceDim> reduceDims,
-                       const std::array<const void*, NumReduction> alphas,
-                       const std::array<const void*, NumReduction> betas,
+                       const std::array<double, NumReduction> alphas,
+                       const std::array<double, NumReduction> betas,
                        const void* in_dev,
                        const std::array<void*, NumReduction> out_dev_buffers,
                        const InElementwiseOperationTuple in_elementwise_op_tuple,
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp

@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
-                const std::array<const void*, NumReduction>& alphas,
-                const std::array<const void*, NumReduction>& betas,
+                const std::array<double, NumReduction>& alphas,
+                const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,

@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
            for(size_t i = 0; i < NumReduction; i++)
            {
-                alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
-                beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
+                alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
+                beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                        const std::array<index_t, NumOutputDim> outLengths,
                        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
                        const std::array<int, NumReduceDim> reduceDims,
-                       const std::array<const void*, NumReduction> alphas,
-                       const std::array<const void*, NumReduction> betas,
+                       const std::array<double, NumReduction> alphas,
+                       const std::array<double, NumReduction> betas,
                        const void* in_dev,
                        const std::array<void*, NumReduction> out_dev_buffers,
                        const InElementwiseOperationTuple in_elementwise_op_tuple,
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp

@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 const std::vector<index_t> yStrides,
                 const std::vector<index_t> reduceDims,
                 AccElementwiseOperation acc_elementwise_op,
-                AccDataType epsilon,
+                double epsilon,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
-            : epsilon_(epsilon),
-              p_x_(p_x),
+            : p_x_(p_x),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              acc_elementwise_op_(acc_elementwise_op)
        {
+            epsilon_ = static_cast<AccDataType>(epsilon);
+
            Lengths_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);

@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
-                       AccDataType epsilon,
+                       double epsilon,
                        const void* p_x,
                        const void* p_gamma,
                        const void* p_beta,
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp

@@ -40,8 +40,16 @@ template <typename InDataType,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,

@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
    static constexpr bool use_multiblock =
        (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

-    static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                      OutDataType>::value,
+    static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                       OutDataType>::value,
                  "The OutDataType must support the specified OutMemoryDataOperation!");

    static_assert(!use_multiblock || (use_multiblock && !OutputIndex),

@@ -209,8 +217,8 @@ struct DeviceReduceMultiBlock
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
-                float alpha,
-                float beta,
+                double alpha,
+                double beta,
                 const InDataType* in_dev,
                 const IndexDataType* in_index_dev,
                 OutDataType* out_dev,

@@ -494,8 +502,8 @@ struct DeviceReduceMultiBlock
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                       float alpha,
-                       float beta,
+                       double alpha,
+                       double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp

@@ -35,8 +35,17 @@ template <typename InDataType,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
-struct DeviceReduceThreadWise
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

@@ -156,8 +165,8 @@ struct DeviceReduceThreadWise
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
-                float alpha,
-                float beta,
+                double alpha,
+                double beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_index_dev,

@@ -332,8 +341,8 @@ struct DeviceReduceThreadWise
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                       float alpha,
-                       float beta,
+                       double alpha,
+                       double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp

@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> inStrides,
                 const std::vector<index_t> reduceDims,
-                AccDataType alpha,
-                AccDataType beta,
+                double alpha,
+                double beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 InElementwiseOp in_elementwise_op,
                 AccElementwiseOp acc_elementwise_op)
-            : alpha_{alpha},
-              beta_{beta},
-              in_dev_{in_dev},
+            : in_dev_{in_dev},
              out_dev_{out_dev},
              in_elementwise_op_{in_elementwise_op},
              acc_elementwise_op_{acc_elementwise_op}
        {
+            alpha_ = static_cast<AccDataType>(alpha);
+            beta_  = static_cast<AccDataType>(beta);
+
            if(Rank != inLengths.size() || Rank != inStrides.size() ||
               NumReduceDim != reduceDims.size())
            {

@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
    static auto MakeArgument(const std::vector<index_t> inLengths,
                             const std::vector<index_t> inStrides,
                             const std::vector<int> reduceDims,
-                             const AccDataType alpha,
-                             const AccDataType beta,
+                             double alpha,
+                             double beta,
                             const InDataType* in_dev,
                             OutDataType* out_dev,
                             InElementwiseOp in_elementwise_op,

@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> inLengths,
                        const std::vector<index_t> inStrides,
                        const std::vector<int> reduceDims,
-                        const void* alpha,
-                        const void* beta,
+                        double alpha,
+                        double beta,
                        const void* in_dev,
                        void* out_dev,
                        InElementwiseOp in_elementwise_op,

@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          reduceDims,
-                                          *static_cast<const AccDataType*>(alpha),
-                                          *static_cast<const AccDataType*>(beta),
+                                          alpha,
+                                          beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          in_elementwise_op,
include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp → include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp

@@ -12,7 +12,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"

namespace ck {
namespace tensor_operation {

@@ -24,16 +24,17 @@ template <typename EmbType,
          typename BetaDataType,
          typename AccDataType,
          typename OutType,
          typename EmbElementwiseOperation,
          ck::index_t BlockSize,
          ck::index_t DimClusterSize,
          ck::index_t RowClusterSize,
          ck::index_t DimPerBlock,
          ck::index_t RowPerBlock,
          ck::index_t DimThreadSize,
-          ck::index_t RowVectorSize>
-struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
+          ck::index_t RowVectorSize,
+          ck::index_t NumEmbeddings>
+struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
{
    static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
    {
        return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));

@@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
    struct Argument : public BaseArgument
    {
        Argument(OutType* p_out,
-                const EmbType* p_emb_a,
-                const EmbType* p_emb_b,
-                const EmbType* p_emb_c,
-                const IndexType* p_index_a,
-                const IndexType* p_index_b,
-                const IndexType* p_index_c,
+                const ck::Array<EmbType*, NumEmbeddings>& p_embs,
+                const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 const ck::index_t NumRows,
                 const ck::index_t EmbeddingDim,
                 const ck::index_t IndexLength,
-                const AccDataType epsilon)
+                const AccDataType epsilon,
+                const EmbElementwiseOperation emb_elementwise_op)
            : p_out_(p_out),
-              p_emb_a_(p_emb_a),
-              p_emb_b_(p_emb_b),
-              p_emb_c_(p_emb_c),
-              p_index_a_(p_index_a),
-              p_index_b_(p_index_b),
-              p_index_c_(p_index_c),
+              p_embs_(p_embs),
+              p_indexs_(p_indexs),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              NumRows_(NumRows),
              EmbeddingDim_(EmbeddingDim),
              IndexLength_(IndexLength),
-              epsilon_(epsilon)
+              epsilon_(epsilon),
+              emb_elementwise_op_(emb_elementwise_op)
        {
            grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
        }

        OutType* p_out_;
-        const EmbType* p_emb_a_;
-        const EmbType* p_emb_b_;
-        const EmbType* p_emb_c_;
-        const IndexType* p_index_a_;
-        const IndexType* p_index_b_;
-        const IndexType* p_index_c_;
+        ck::Array<EmbType*, NumEmbeddings> p_embs_;
+        ck::Array<IndexType*, NumEmbeddings> p_indexs_;
        const GammaDataType* p_gamma_;
        const BetaDataType* p_beta_;
        ck::index_t NumRows_;
        ck::index_t EmbeddingDim_;
        ck::index_t IndexLength_;
        AccDataType epsilon_;
+        EmbElementwiseOperation emb_elementwise_op_;
        size_t grid_size_;
    };

-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
-                                                               const void* p_emb_a,
-                                                               const void* p_emb_b,
-                                                               const void* p_emb_c,
-                                                               const void* p_index_a,
-                                                               const void* p_index_b,
-                                                               const void* p_index_c,
-                                                               const void* p_gamma,
-                                                               const void* p_beta,
-                                                               ck::index_t NumRows,
-                                                               ck::index_t EmbeddingDim,
-                                                               ck::index_t IndexLength,
-                                                               const AccDataType epsilon)
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
+                                                      const ck::Array<EmbType*, NumEmbeddings>& p_embs,
+                                                      const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
+                                                      const void* p_gamma,
+                                                      const void* p_beta,
+                                                      ck::index_t EmbeddingDim,
+                                                      ck::index_t IndexLength,
+                                                      const AccDataType epsilon,
+                                                      const EmbElementwiseOperation emb_elementwise_op)
    {
        return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
-                                          reinterpret_cast<const EmbType*>(p_emb_a),
-                                          reinterpret_cast<const EmbType*>(p_emb_b),
-                                          reinterpret_cast<const EmbType*>(p_emb_c),
-                                          reinterpret_cast<const IndexType*>(p_index_a),
-                                          reinterpret_cast<const IndexType*>(p_index_b),
-                                          reinterpret_cast<const IndexType*>(p_index_c),
+                                          p_embs,
+                                          p_indexs,
                                          reinterpret_cast<const GammaDataType*>(p_gamma),
                                          reinterpret_cast<const BetaDataType*>(p_beta),
                                          NumRows,
                                          EmbeddingDim,
                                          IndexLength,
-                                          epsilon);
+                                          epsilon,
+                                          emb_elementwise_op);
    }

    using GridwiseSparseEmbedding =
-        GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
+        GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
                                                  IndexType,
                                                  GammaDataType,
                                                  BetaDataType,
                                                  AccDataType,
                                                  OutType,
                                                  decltype(MakeOutputDescriptor(1, 1)),
                                                  EmbElementwiseOperation,
                                                  BlockSize,
                                                  DimClusterSize,
                                                  RowClusterSize,
                                                  DimPerBlock,
                                                  RowPerBlock,
                                                  DimThreadSize,
-                                                 RowVectorSize>;
+                                                 RowVectorSize,
+                                                 NumEmbeddings>;

    struct Invoker : public BaseInvoker
    {

@@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
        {
            auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
            const auto kernel_main =
-                kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
+                kernel_sparse_embeddings_forward_layernorm<GridwiseSparseEmbedding,
                                                            EmbType,
                                                            IndexType,
                                                            GammaDataType,
                                                            BetaDataType,
                                                            AccDataType,
                                                            OutType,
-                                                           decltype(out_desc)>;
+                                                           decltype(out_desc),
+                                                           EmbElementwiseOperation,
+                                                           NumEmbeddings>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,

@@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
                                               dim3(BlockSize),
                                               0,
                                               arg.p_out_,
-                                               arg.p_emb_a_,
-                                               arg.p_emb_b_,
-                                               arg.p_emb_c_,
-                                               arg.p_index_a_,
-                                               arg.p_index_b_,
-                                               arg.p_index_c_,
+                                               arg.p_embs_,
+                                               arg.p_indexs_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               out_desc,
-                                               arg.epsilon_);
+                                               arg.epsilon_,
+                                               arg.emb_elementwise_op_);

            return (avg_time);
        }

@@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
    static bool IsSupportedArgument(const Argument* p_arg)
    {
-        return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
+        return (RowPerBlock == p_arg->EmbeddingDim_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override

@@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
        auto str = std::stringstream();
        // clang-format off
-        str << "DeviceSparseEmbedding3ForwardLayernorm_" << BlockSize << "_" <<
+        str << "DeviceSparseEmbeddingsForwardLayernorm_" << BlockSize << "_" <<
            DimClusterSize << "x" << RowClusterSize << "_" <<
            DimPerBlock << "x" << RowPerBlock << "_" <<
            DimThreadSize << "x" << RowVectorSize;
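With the rename to DeviceSparseEmbeddingsForwardLayernorm, the three hard-coded table/index pointers collapse into ck::Array arguments sized by the new NumEmbeddings parameter, and the elementwise combiner is passed explicitly. A hedged sketch of building an argument for three tables (buffer names are placeholders; the exact parameter order follows the signature shown in the diff above):

    // Sketch: three embedding tables gathered and fused with layernorm (NumEmbeddings == 3).
    ck::Array<EmbType*, 3>   p_embs{p_emb_a_dev, p_emb_b_dev, p_emb_c_dev};
    ck::Array<IndexType*, 3> p_indexs{p_index_a_dev, p_index_b_dev, p_index_c_dev};

    auto arg_ptr = device_op.MakeArgumentPointer(p_out_dev, p_embs, p_indexs,
                                                 p_gamma_dev, p_beta_dev,
                                                 EmbeddingDim, IndexLength, epsilon,
                                                 EmbElementwiseOperation{});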
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -172,6 +172,42 @@ struct AddAdd
    }
};

+// C = A * B
+// E = (C + D0) x D1
+struct AddMultiply
+{
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
+                                                                        const half_t& c,
+                                                                        const half_t& d0,
+                                                                        const half_t& d1) const
+    {
+        const half_t y = (c + d0) * d1;
+        e              = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
+                                                                       const float& c,
+                                                                       const half_t& d0,
+                                                                       const half_t& d1) const
+    {
+        const half_t y = (type_convert<half_t>(c) + d0) * d1;
+        e              = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
+                                                                      const float& c,
+                                                                      const half_t& d0,
+                                                                      const half_t& d1) const
+    {
+        const float y = (c + d0) * d1;
+        e             = y;
+    }
+};
+
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu

@@ -278,6 +314,40 @@ struct Normalize
    double epsilon_;
};

+// used by BatchNorm inference
+// y = gamma * (x - mean) / sqrt(epsilon + variance) + beta
+// The data type of mean and variance is used as AccDataType
+struct NormalizeInInfer
+{
+    NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
+
+    template <typename T1, typename T2, typename T3, typename T4>
+    __host__ __device__ constexpr void operator()(T1& y,
+                                                  const T1& x,
+                                                  const T2& mean,
+                                                  const T2& variance,
+                                                  const T3& gamma,
+                                                  const T4& beta) const
+    {
+        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
+                      "Data type is not supported by this operation!");
+
+        using ck::type_convert;
+        using ck::math::sqrt;
+
+        T2 tmp_x, tmp_y;
+
+        tmp_x = type_convert<T2>(x);
+
+        tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
+                    type_convert<T2>(gamma) +
+                type_convert<T2>(beta);
+        y = type_convert<T1>(tmp_y);
+    };
+
+    double epsilon_;
+};
+
template <typename Y, typename X>
struct UnaryTypeConvert;
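The new NormalizeInInfer functor implements the batch-norm inference formula y = gamma * (x - mean) / sqrt(epsilon + variance) + beta, computing in the type of mean/variance. A tiny standalone reference in plain C++ (not the CK functor itself, just the same arithmetic) can be used to sanity-check instances that rely on it:

    #include <cmath>
    #include <cstdio>

    // Host-side reference of the NormalizeInInfer arithmetic (illustration only,
    // with T2 = float as the accumulation type).
    float normalize_in_infer_ref(float x, float mean, float variance, float gamma, float beta,
                                 double epsilon = 1e-4)
    {
        const float tmp = (x - mean) / std::sqrt(variance + static_cast<float>(epsilon));
        return tmp * gamma + beta;
    }

    int main()
    {
        // x = 2, mean = 1, variance = 4, gamma = 0.5, beta = 0.25  ->  ~0.5
        std::printf("%f\n", normalize_in_infer_ref(2.f, 1.f, 4.f, 0.5f, 0.25f));
        return 0;
    }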
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

+        /**
+         *  idxN0
+         *
+         *  |<  mtx   N  >|
+         *
+         *    NPerBlock   NPerBlock   NPerBlock   NPerBlock
+         *       N_0         N_1         N_2         N_3
+         *  - |-----------|-----------|-----------|-----|-----|-
+         *  ^ | -  -  0   |/----> 2   |           |     |     |
+         *  | |  |      | /   |       |           |     |     |  M_0  MPerBlock
+         *  | M  |      |/    |       |           |     |     |
+         *  |-0---|---/-|-----|-----|-----------|-----|-----|-
+         *  |  1  |  /  |     |     |  blockid  |     |     |
+         *  idxM0 |  | | / |  V  |      5      |     |     |  M_1  MPerBlock
+         *  |  -  V  1  |  -  3   |           |     |     |
+         *  |-----------|-----------|-----------|-----|-----|-
+         *  mtx M |           |           |           |     |     |
+         *  |           |           |           |     |     |  M_2  MPerBlock
+         *  |           |           |           |     |     |
+         *  |-----------|-----------|-----------|-----|-----|-
+         *  |           |           |           |     |     |
+         *  |           |           |           |     |     |  M_3  MPerBlock
+         *  |           |           |           |     |     |
+         *  |-----------|-----------|-----------|-----|-----|-
+         *  V |           |           |           |     |     |
+         *  - |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
+         *  |           |           |           |     |     |
+         *  |-----------|-----------|-----------|-----|-----|-
+         *  Example:
+         *   assume:
+         *      M0 = 5
+         *      N0 = 4
+         *      block_1d_id = 5
+         *      M01 = 2
+         *
+         *   idx_N0 = 1
+         *   idx_M0 = 1
+         *   M01_adapt = 2
+         *   idx_M00 = 0
+         *   idx_M01 = 1
+         *   idx_N0_M01_local = 5
+         *   output {1, 2}
+         */
        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }
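The comment added above ends with a worked example (M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2 mapping to the C tile {1, 2}). The index arithmetic it walks through can be reproduced in isolation; note that the initial idx_N0/idx_M0 decomposition and the M01_adapt clamp are not visible in this hunk, so the first three lines below are assumptions based on the values the comment lists:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int M0 = 5, N0 = 4, M01 = 2, block_1d_id = 5;

        const int idx_N0    = block_1d_id % N0;                       // 1 (assumed decomposition)
        const int idx_M0    = block_1d_id / N0;                       // 1 (assumed decomposition)
        const int idx_M00   = idx_M0 / M01;                           // 0
        const int idx_M01   = idx_M0 % M01;                           // 1
        const int M01_adapt = std::min(M01, M0 - idx_M00 * M01);      // 2 (assumed clamp)
        const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;           // 5

        const int m = idx_N0_M01_local % M01_adapt + idx_M00 * M01;   // 1
        const int n = idx_N0_M01_local / M01_adapt;                   // 2

        std::printf("output {%d, %d}\n", m, n);                       // output {1, 2}
        return 0;
    }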
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

namespace ck {

template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;

// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
    {
        // TODO: improve applicability
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
                                               const ABlockDesc& a_block_desc,
                                               ABlockTransfer& a_blockwise_copy,
                                               const AGridBuffer& a_grid_buf,
                                               ABlockBuffer& a_block_buf,
                                               const ABlockTransferStep& a_block_copy_step,
                                               const BGridDesc& b_grid_desc,
                                               const BBlockDesc& b_block_desc,
                                               BBlockTransfer& b_blockwise_copy,
                                               const BGridBuffer& b_grid_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BBlockTransferStep& b_block_copy_step,
                                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                // sync for Load threads()
                block_sync_lds();

                // global read i + 1
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                // move to i + 2
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // sync with math threads()
                block_sync_lds();

                // LDS write i + 1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();
            // GEMM num_loop - 1
        }
    }
};

template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;

// 1-stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BlockwiseGemm& block_gemm,
                                               CThreadBuffer& c_thread_buf,
                                               index_t num_loop)
    {
        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                // GEMM i
                block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};

} // namespace ck
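The two specializations split one workgroup into a producer wave (global-memory to LDS copies) and a consumer wave (blockwise GEMM out of LDS), each running its own 1-stage-prefetch pipeline around shared block_sync_lds() points. A heavily simplified sketch of how a kernel might dispatch between them, assuming thread groups with the IsBelong() predicates defined in the wavelet-model CShuffle gridwise GEMM below; this illustrates the intent and is not code from the commit:

    // Illustration only: LoadGroup/MathGroup stand for thread groups exposing IsBelong(),
    // as in TileLoadThreadGroup/TileMathThreadGroup of the wavelet-model gridwise GEMM.
    template <typename LoadWave, typename MathWave, typename LoadGroup, typename MathGroup>
    __device__ void run_wavelet_gemm(/* descriptors, copies, buffers, ... */ ck::index_t num_loop)
    {
        if(LoadGroup::IsBelong())
        {
            // producer threads: stream A/B tiles from global memory into LDS
            // LoadWave::template RunLoadWavePipeline<HasMainLoop>(..., num_loop);
        }
        else if(MathGroup::IsBelong())
        {
            // consumer threads: run the blockwise GEMM out of LDS into the C thread buffer
            // MathWave::template RunMathWavePipeline<HasMainLoop>(..., num_loop);
        }
    }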
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
0 → 100644
View file @
644df335
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename ABDataType,
          typename FloatGemmAcc,
          typename EDataTypeShuffle,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename EElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_M_K,
          typename BGridDesc_N_K,
          typename EGridDesc_M_N,
          index_t NumGemmKPrefetchStage,
          index_t TileLoadThreadGroupSize,
          index_t TileMathThreadGroupSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};

    struct TileLoadThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
        }

        __device__ static index_t GetThreadId()
        {
            return get_thread_local_1d_id() - TileMathThreadGroupSize;
        }
    };

    struct TileMathThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return get_thread_local_1d_id() < TileMathThreadGroupSize;
        }

        __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
    };

    using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;

    // load and math+store Wave pipelines.
    // TODO: build pipelines blocks scheduling parallel tasks
    using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
    using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup, NumGemmKPrefetchStage>;
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    __host__ __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(ABDataType),
                         c_block_size * sizeof(EDataTypeShuffle));
    }
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2ETileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                  const BGridDesc_N_K& b_grid_desc_n_k,
                  const EGridDesc_M_N& e_grid_desc_m_n,
                  const Block2ETileMap& /*block_2_etile_map*/)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        // check consistency of desc
        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
             K == b_grid_desc_n_k.GetLength(I1)))
        {
            return false;
        }

        // check tile size
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
        {
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = K / KPerBlock;

        if(!GridwiseGemmMath::IsSupported(num_k_loop))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        // check tensor size: cannot be larger than 2GB each
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
             b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
             e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
        {
            return false;
        }

        return true;
    }
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
    }

    // return block_id to E matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M01 = I1;
        constexpr auto N01 = I1;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M0, M01)),
                           make_unmerge_transform(make_tuple(N0, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }
    // A desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // B desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // E desc for destination in blockwise copy
    template <typename EGridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            e_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return e_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

    using DefaultBlock2ETileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename Block2ETileMap>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
                               EDataType* __restrict__ p_e_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const EElementwiseOperation& e_element_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2ETileMap& block_2_etile_map)
    {
        // build loadWave and MathWave pipelines
        // loadWave and MathWave synchronized through LDS

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
        if(TileLoadThreadGroup::IsBelong())
        {
            // LoadWave
            const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());

            const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

            // A matrix blockwise copy
            auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                TileLoadThreadGroup,
                AElementwiseOperation,
                ck::tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<AK0PerBlock, MPerBlock, AK1>,
                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                ABlockTransferThreadClusterArrangeOrder,
                ABDataType,
                ABDataType,
                decltype(a_grid_desc_ak0_m_ak1),
                decltype(a_block_desc_ak0_m_ak1),
                ABlockTransferSrcAccessOrder,
                Sequence<1, 0, 2>,
                ABlockTransferSrcVectorDim,
                2,
                ABlockTransferSrcScalarPerVector,
                ABlockTransferDstScalarPerVector_AK1,
                1,
                1,
                AThreadTransferSrcResetCoordinateAfterRun,
                true,
                NumGemmKPrefetchStage>(a_grid_desc_ak0_m_ak1,
                                       make_multi_index(0, m_block_data_idx_on_grid, 0),
                                       a_element_op,
                                       a_block_desc_ak0_m_ak1,
                                       make_multi_index(0, 0, 0),
                                       ck::tensor_operation::element_wise::PassThrough{});

            // B matrix blockwise copy
            auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                TileLoadThreadGroup,
                BElementwiseOperation,
                ck::tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<BK0PerBlock, NPerBlock, BK1>,
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                ABDataType,
                ABDataType,
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                Sequence<1, 0, 2>,
                BBlockTransferSrcVectorDim,
                2,
                BBlockTransferSrcScalarPerVector,
                BBlockTransferDstScalarPerVector_BK1,
                1,
                1,
                BThreadTransferSrcResetCoordinateAfterRun,
                true,
                NumGemmKPrefetchStage>(b_grid_desc_bk0_n_bk1,
                                       make_multi_index(0, n_block_data_idx_on_grid, 0),
                                       b_element_op,
                                       b_block_desc_bk0_n_bk1,
                                       make_multi_index(0, 0, 0),
                                       ck::tensor_operation::element_wise::PassThrough{});

            GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                a_grid_buf,
                a_block_buf,
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy,
                b_grid_buf,
                b_block_buf,
                b_block_slice_copy_step,
                num_k_block_main_loop);

            block_sync_lds();
            block_sync_lds();
        }
        else if(TileMathThreadGroup::IsBelong())
        {
            // branch early for math wave
            constexpr index_t KPack =
                math::max(math::lcm(AK1, BK1),
                          MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

            auto blockwise_gemm =
                BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<TileMathThreadGroupSize,
                                                                    ABDataType,
                                                                    FloatGemmAcc,
                                                                    decltype(a_block_desc_ak0_m_ak1),
                                                                    decltype(b_block_desc_bk0_n_bk1),
                                                                    MPerXdl,
                                                                    NPerXdl,
                                                                    MXdlPerWave,
                                                                    NXdlPerWave,
                                                                    KPack>{};

            auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // TODO re-architect LDS+math stages
            // Writing data to GMEM: only math wave is doing the work in cshuffle
            GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
                a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);

            // GEMM definition
            //   c_mtx += transpose(a_mtx) * b_mtx
            //     a_mtx[K0PerBlock, MPerBlock] is in LDS
            //     b_mtx[K0PerBlock, NPerBlock] is in LDS
            //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
            // sanity check

            // shuffle C and write out
            {
                static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                              "wrong!");

                constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
                constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

                // TODO: hacky, fix it!
                constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                // TODO: hacky, fix it!
                // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
                constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
                constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
                constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
                constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
                constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
                constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
                constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

                constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

                auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<EDataTypeShuffle*>(p_shared),
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                            M1,                                      // M1 = MWave
                            M2,                                      // M2 * M3 * M4 = MPerXdl
                            M3,
                            M4)),
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                            N1,                                      // N1 = NWave
                            N2))),                                   // N2 = NPerXdl
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(
                        Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
                // calculate origin of thread output tensor on global memory
                //     blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                        make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                        make_tuple(Sequence<0, 1, 2>{}),
                        make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // shuffle: threadwise copy C from VGPR to LDS
                auto c_thread_copy_vgpr_to_lds =
                    ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
                                                       EDataTypeShuffle,
                                                       decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                       decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                       ck::tensor_operation::element_wise::PassThrough,
                                                       Sequence<CShuffleMXdlPerWavePerShuffle,
                                                                CShuffleNXdlPerWavePerShuffle,
                                                                I1,
                                                                I1,
                                                                M2,
                                                                I1,
                                                                M4,
                                                                I1>,
                                                       Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                       7,
                                                       1,
                                                       InMemoryDataOperationEnum::Set,
                                                       1,
                                                       true>{
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        make_multi_index(0,
                                         0,
                                         m_thread_data_on_block_idx[I1],
                                         n_thread_data_on_block_idx[I1],
                                         m_thread_data_on_block_idx[I2],
                                         m_thread_data_on_block_idx[I3],
                                         m_thread_data_on_block_idx[I4],
                                         n_thread_data_on_block_idx[I2]),
                        ck::tensor_operation::element_wise::PassThrough{}};
                // shuffle: blockwise copy C from LDS to global
                auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                    CShuffleBlockTransferThreadGroup, // ThreadGroup
                    EElementwiseOperation,            // ElementwiseOperation,
                    CGlobalMemoryDataOperation,       // DstInMemOp,
                    Sequence<1,
                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                             1,
                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                    EDataTypeShuffle,     // typename SrcData,
                    EDataType,            // typename DstData,
                    decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                    decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                    3,                                              // index_t VectorDim,
                    CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                    true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                    false> // bool ThreadTransferDstResetCoordinateAfterRun>
                    {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(0, 0, 0, 0),
                     e_grid_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                     e_element_op};

                // space filling curve for threadwise C in VGPR
                constexpr auto sfc_c_vgpr =
                    SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                      Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                      Sequence<CShuffleMXdlPerWavePerShuffle,
                                               CShuffleNXdlPerWavePerShuffle,
                                               1,
                                               1,
                                               M2,
                                               1,
                                               M4,
                                               1>>{};
                // space filling curve for shuffled blockwise C in global mem
                constexpr auto sfc_c_global =
                    SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                      Sequence<0, 2, 1, 3>,
                                      Sequence<1,
                                               CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                               1,
                                               CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

                constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

                static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

                // Different way of getting coalesced writes:
                // We can get rid of doing cshuffle. Instead of reading A rows in contiguous manner,
                // do it interleaved, then mfma can have nice c-mat layout as below:
                //
                // TODO
                // We do not need to do LDS swizzle to align global writes writing cache lines:
                //    v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN
                //                                    elements (N is vertical or strided dimension)
                //    v_mfma cmat, bmat, amat, cmat - c-mat register layout are Mx1
                //                                    elements (M is coalescing dimension);
                //                                    by enumerating M index in amat, bmat you can
                //                                    align cmat register(s) to contiguous M elements
                // for example
                //    1st mfma instruction output space : 0 4 8  12 16 ....
                //    2nd mfma instruction output space : 1 5 9  13 17 ....
                //    3rd mfma instruction output space : 2 6 10 14 18 ....
                //    4th mfma instruction output space : 3 7 11 15 19 ....
                // you can pack 4 registers output space into 2WORD and do global write
                // (no LDS swizzling required)
                static_for<0, num_access, 1>{}([&](auto access_id) {
                    // make sure it's safe to write to LDS
                    block_sync_lds();

                    // each thread write its data from VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                                  c_thread_buf,
                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                  c_shuffle_block_buf);

                    // make sure it's safe to read from LDS
                    block_sync_lds();

                    // each block copy its data from LDS to global
                    c_shuffle_block_copy_lds_to_global.Run(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                        c_shuffle_block_buf,
                        e_grid_desc_mblock_mperblock_nblock_nperblock,
                        c_grid_buf);

                    if constexpr(access_id < num_access - 1)
                    {
                        constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                        // move on C
                        c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                            e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                    }
                });
            }
        }
    }
};

} // namespace ck
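For orientation, the launch configuration this gridwise op implies is easy to compute on the host: the grid has one workgroup per MPerBlock x NPerBlock tile of E (see CalculateGridSize above), each workgroup carries TileLoadThreadGroupSize + TileMathThreadGroupSize threads, and GetSharedMemoryNumberOfByte bounds the LDS it needs. The sketch below is a back-of-the-envelope version of that arithmetic; the problem size, tile sizes, group sizes, and data-type sizes in it are made up for illustration and do not correspond to a tuned instance, and the LDS estimate ignores the ABlockLdsExtraM/BBlockLdsExtraN padding.

// Illustrative launch-config arithmetic for the wavelet-model GEMM (host C++).
#include <algorithm>
#include <cstdio>

int main()
{
    // problem size (illustrative)
    const int M = 3840, N = 4096, K = 4096;

    // tile / thread-group parameters (illustrative, not a tuned instance)
    const int MPerBlock = 256, NPerBlock = 128, KPerBlock = 32;
    const int TileLoadThreadGroupSize = 256, TileMathThreadGroupSize = 256;

    // grid size: one workgroup per (MPerBlock x NPerBlock) tile of E, as in CalculateGridSize()
    const int grid_size = (M / MPerBlock) * (N / NPerBlock);

    // block size: load wave and math wave live in the same workgroup
    const int block_size = TileLoadThreadGroupSize + TileMathThreadGroupSize;

    // rough LDS estimate, mirroring the max() in GetSharedMemoryNumberOfByte()
    const int sizeof_ab        = 2; // e.g. fp16 A/B
    const int sizeof_e_shuffle = 4; // e.g. fp32 shuffle staging
    const int ab_tile_bytes    = (MPerBlock + NPerBlock) * KPerBlock * sizeof_ab;
    const int c_shuffle_bytes  = MPerBlock * NPerBlock / 4 * sizeof_e_shuffle; // one shuffle slice
    const int lds_bytes        = std::max(ab_tile_bytes, c_shuffle_bytes);

    const int num_k_loop = K / KPerBlock;

    std::printf("grid = %d blocks of %d threads, ~%d bytes LDS, %d K loops\n",
                grid_size, block_size, lds_bytes, num_k_loop);
}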
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp → include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
...
...
@@ -17,33 +17,24 @@ template <typename GridwiseSparseEmbedding,
           typename BetaDataType,
           typename AccDataType,
           typename OutType,
-          typename OutGridDesc>
+          typename OutGridDesc,
+          typename EmbElementwiseOperation,
+          ck::index_t NumEmbeddings>
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
-                                                           const EmbType* p_emb_a,
-                                                           const EmbType* p_emb_b,
-                                                           const EmbType* p_emb_c,
-                                                           const IndexType* p_index_a,
-                                                           const IndexType* p_index_b,
-                                                           const IndexType* p_index_c,
-                                                           const GammaDataType* p_gamma,
-                                                           const BetaDataType* p_beta,
-                                                           const OutGridDesc out_grid_desc,
-                                                           const AccDataType epsilon)
+__global__ void kernel_sparse_embeddings_forward_layernorm(OutType* p_out,
+                                                           const ck::Array<EmbType*, NumEmbeddings> p_embs,
+                                                           const ck::Array<IndexType*, NumEmbeddings> p_indexes,
+                                                           const GammaDataType* p_gamma,
+                                                           const BetaDataType* p_beta,
+                                                           const OutGridDesc out_grid_desc,
+                                                           const AccDataType epsilon,
+                                                           const EmbElementwiseOperation emb_elementwise_op)
 {
-    GridwiseSparseEmbedding::Run(
-        p_out, p_emb_a, p_emb_b, p_emb_c, p_index_a, p_index_b, p_index_c, p_gamma, p_beta, out_grid_desc, epsilon);
+    GridwiseSparseEmbedding::Run(
+        p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
 }

 template <typename EmbType,
...
...
@@ -53,14 +44,16 @@ template <typename EmbType,
           typename AccDataType,
           typename OutType,
           typename OutGridDesc,
+          typename EmbElementwiseOperation,
           ck::index_t BlockSize,
           ck::index_t DimClusterSize,
           ck::index_t RowClusterSize,
           ck::index_t DimPerBlock,   // Row x Dim, along Dim
           ck::index_t RowPerBlock,   // Row x Dim, along Row
           ck::index_t DimThreadSize, // this is actually not vector, but number of registers
-          ck::index_t RowVectorSize>
-struct GridwiseSparseEmbedding3ForwardLayernorm
+          ck::index_t RowVectorSize,
+          ck::index_t NumEmbeddings>
+struct GridwiseSparseEmbeddingsForwardLayernorm
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
...
@@ -97,23 +90,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;

     __device__ static void Run(OutType* p_out,
-                               const EmbType* p_emb_a,
-                               const EmbType* p_emb_b,
-                               const EmbType* p_emb_c,
-                               const IndexType* p_index_a,
-                               const IndexType* p_index_b,
-                               const IndexType* p_index_c,
+                               const ck::Array<EmbType*, NumEmbeddings> p_embs,
+                               const ck::Array<IndexType*, NumEmbeddings> p_indexes,
                                const GammaDataType* p_gamma,
                                const BetaDataType* p_beta,
                                const OutGridDesc,
-                               const AccDataType epsilon)
+                               const AccDataType epsilon,
+                               const EmbElementwiseOperation emb_elementwise_op)
     {
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
-        // const auto index_length = out_grid_desc.GetLength(I0);
-        // const auto emb_dim = out_grid_desc.GetLength(I1);

         constexpr auto thread_cluster_desc =
             make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
...
...
@@ -141,13 +128,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         constexpr auto gamma_beta_buf_desc =
             make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));

-        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
-        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
-        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
-        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
-        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
-        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
+        ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>, NumEmbeddings> in_thread_bufs;
+        ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings> index_bufs;

         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
...
...
@@ -160,42 +145,31 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;

         auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
-            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
-            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
-            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
-            using src_vector_t = typename decltype(emb_vector_a)::type;
+            ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
+            auto emb_a         = emb_vectors[0];
+            using src_vector_t = typename decltype(emb_a)::type;
             static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
                 constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
-                IndexType index_a = index_buf_a[Number<current_dim>{}];
-                IndexType index_b = index_buf_b[Number<current_dim>{}];
-                IndexType index_c = index_buf_c[Number<current_dim>{}];
                 auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
                                      sizeof(EmbType) * RowVectorSize;
+                static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
+                    IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
-                int32x4_t emb_res_a =
-                    make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
-                int32x4_t emb_res_b =
-                    make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
-                int32x4_t emb_res_c =
-                    make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
-                emb_vector_a.template AsType<src_vector_t>()(I0) =
-                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
-                emb_vector_b.template AsType<src_vector_t>()(I0) =
-                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
-                emb_vector_c.template AsType<src_vector_t>()(I0) =
-                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
+                    int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
+                        p_embs[i_embedding_] + index * RowPerBlock);
+                    emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
+                        amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
+                });
                 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                     constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                         make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                    in_thread_buf_a(Number<register_offset>{}) =
-                        emb_vector_a.template AsType<EmbType>()[i_row_vec_];
-                    in_thread_buf_b(Number<register_offset>{}) =
-                        emb_vector_b.template AsType<EmbType>()[i_row_vec_];
-                    in_thread_buf_c(Number<register_offset>{}) =
-                        emb_vector_c.template AsType<EmbType>()[i_row_vec_];
+                    static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
+                        in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
+                            ck::type_convert<AccDataType>(
+                                emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
+                    });
                 });
             });
         };
...
...
@@ -205,14 +179,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
             static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                     make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                AccDataType va = ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
-                AccDataType vb = ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
-                AccDataType vc = ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
-                acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
+                auto in_data_refs = generate_tie(
+                    [&](auto i_embedding_) -> const auto& {
+                        return in_thread_bufs(i_embedding_)(Number<register_offset>{});
+                    },
+                    Number<NumEmbeddings>{});
+                auto out_data_refs = generate_tie(
+                    [&](auto output_index_) -> auto& {
+                        return acc_thread_buf(Number<register_offset>{});
+                    },
+                    Number<1>{});
+                unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
             });
         });
     };
...
...
@@ -242,7 +219,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
                 constexpr auto mean_var_offset =
                     mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
+                auto divisor =
+                    1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
                 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                     constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                         make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
...
...
@@ -250,9 +228,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
                         gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
                     auto acc_val = acc_thread_buf[Number<register_offset>{}];
-                    acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
-                              sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
-                    acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
+                    acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
+                    acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
                               beta_thread_buf[Number<gamma_beta_offset>{}];
                     out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
...
@@ -273,9 +250,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         // first load index
         ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
             // prefer use s_load
-            index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
-            index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
-            index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
+            ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
+                index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value];
+            });
         });

         // load gamma/beta
...
...
@@ -329,7 +307,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
             if constexpr(I > 0)
                 block_sync_lds();
             BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
         });
...
...
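Numerically, the refactored kernel still does what the 3-embedding version did, just for a generic NumEmbeddings: gather one row per embedding table, combine the rows element-wise (the old code summed the three rows; the new emb_elementwise_op generalizes that combine), then layer-normalize the combined row using the Welford mean/variance and apply gamma and beta. The only arithmetic change visible in this diff is that 1/sqrt(var + epsilon) is now computed once per row as divisor instead of dividing by sqrt inside the inner loop. The snippet below is a scalar host-side reference of that per-row math, written for illustration only (it is not the CK device code, it uses a plain sum for the combine, and a two-pass mean/variance instead of Welford).

// Scalar reference for sparse-embeddings forward + layernorm on one output row (illustrative).
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> embeddings_forward_layernorm_row(
    const std::vector<std::vector<float>>& gathered_rows, // one gathered row per embedding table
    const std::vector<float>& gamma,
    const std::vector<float>& beta,
    float epsilon)
{
    const size_t dim = gamma.size();

    // 1) element-wise combine of the gathered rows (plain sum, matching the old a + b + c behavior)
    std::vector<float> acc(dim, 0.0f);
    for(const auto& row : gathered_rows)
        for(size_t i = 0; i < dim; ++i)
            acc[i] += row[i];

    // 2) mean / variance over the row (the kernel uses Welford; a two-pass form is equivalent)
    float mean = 0.0f;
    for(float v : acc)
        mean += v;
    mean /= static_cast<float>(dim);

    float var = 0.0f;
    for(float v : acc)
        var += (v - mean) * (v - mean);
    var /= static_cast<float>(dim);

    // 3) normalize with a reciprocal computed once per row, then scale and shift
    const float divisor = 1.0f / std::sqrt(var + epsilon);
    std::vector<float> out(dim);
    for(size_t i = 0; i < dim; ++i)
        out[i] = (acc[i] - mean) * divisor * gamma[i] + beta[i];

    return out;
}

int main()
{
    const std::vector<std::vector<float>> rows = {{1, 2, 3, 4}, {0, 1, 0, 1}, {2, 2, 2, 2}};
    const std::vector<float> gamma = {1, 1, 1, 1}, beta = {0, 0, 0, 0};

    for(float v : embeddings_forward_layernorm_row(rows, gamma, beta, 1e-5f))
        std::printf("%f ", v);
    std::printf("\n");
}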