gaoqiong / composable_kernel · Commits

Commit a768dea5 (Unverified)
Authored Jan 10, 2023 by Rostyslav Geyyer; committed by GitHub on Jan 10, 2023
Parents: 3f976dd0, 0345963e

    Merge branch 'develop' into lwpck-471

Changes: 143 changed files in the commit; showing 20 changed files with 1380 additions and 129 deletions (+1380 / -129).
Changed files shown:

    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp  +12  -1
    include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp  +16  -0
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  +5  -0
    include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp  +230  -0
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp  +678  -0
    include/ck/utility/amd_wmma.hpp  +102  -0
    include/ck/utility/math_v2.hpp  +16  -2
    library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp  +44  -0
    library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp  +0  -116
    library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp  +26  -1
    library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp  +26  -1
    library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp  +26  -1
    library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp  +3  -0
    library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp  +28  -1
    library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp  +28  -1
    library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp  +28  -1
    library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp  +28  -1
    library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp  +28  -1
    library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp  +28  -1
    library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp  +28  -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp (+12 -1)

@@ -373,12 +373,20 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
        gemm_desc_kernel_arg_.reserve(group_count_);

        skipped_group_count_ = 0;

        for(std::size_t i = 0; i < gemm_descs.size(); i++)
        {
            const index_t M = gemm_descs[i].M_;
            const index_t N = gemm_descs[i].N_;
            const index_t K = gemm_descs[i].K_;

            if(M == 0)
            {
                skipped_group_count_++;
                continue;
            }

            const index_t StrideA = gemm_descs[i].stride_A_;
            const index_t StrideB = gemm_descs[i].stride_B_;
            const index_t StrideC = gemm_descs[i].stride_C_;

@@ -470,6 +478,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
        // private:
        index_t group_count_;
        index_t skipped_group_count_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation c_element_op_;

@@ -581,7 +591,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-       if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
+       if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) + arg.skipped_group_count_) !=
+          arg.group_count_)
        {
            return false;
        }
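The invariant introduced here is simple: a group with M == 0 produces no kernel argument, so the kernel-argument count plus the skipped-group count must still equal the requested group count. A minimal host-side sketch of the same bookkeeping, using hypothetical stand-in types rather than the real Argument class:

#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for the descriptor/argument bookkeeping in the hunk above.
struct GemmDesc { int M_, N_, K_; };

struct GroupedGemmArgSketch
{
    int group_count_         = 0;
    int skipped_group_count_ = 0;
    std::vector<GemmDesc> kernel_args_; // one entry per non-empty group

    explicit GroupedGemmArgSketch(const std::vector<GemmDesc>& descs)
        : group_count_(static_cast<int>(descs.size()))
    {
        for(const auto& d : descs)
        {
            if(d.M_ == 0)
            {
                skipped_group_count_++; // empty group: skip it, but remember that we did
                continue;
            }
            kernel_args_.push_back(d);
        }
    }

    // mirrors the updated IsSupportedArgument() check
    bool IsSupported() const
    {
        return static_cast<int>(kernel_args_.size()) + skipped_group_count_ == group_count_;
    }
};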
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp (+16 -0)

@@ -187,6 +187,22 @@ struct AddRelu
        const float a = x0 + type_convert<float>(x1);
        y             = a > 0.0f ? a : 0.0f;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
    {
        const int8_t a = x0 + x1;
        y              = a > 0 ? a : 0;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
    {
        const int8_t a = x0 + x1;
        y              = a > 0 ? a : 0;
    };
};

struct AddHardswish
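The new specializations fuse the add and the ReLU clamp for integer data. A hedged host-only sketch of the same functor shape (the functor body matches the diff; the surrounding test scaffolding is illustrative only):

#include <cassert>
#include <cstdint>

// Illustrative stand-in with the same shape as the new int8_t specialization of
// ck::tensor_operation::element_wise::AddRelu; the real functor is __host__ __device__.
struct AddReluSketch
{
    void operator()(std::int8_t& y, const std::int8_t& x0, const std::int8_t& x1) const
    {
        const std::int8_t a = x0 + x1;
        y                   = a > 0 ? a : 0; // fused add + ReLU
    }
};

int main()
{
    AddReluSketch op;
    std::int8_t y = 0;
    op(y, std::int8_t{-5}, std::int8_t{3}); // -2 clamps to 0
    assert(y == 0);
    op(y, std::int8_t{4}, std::int8_t{3});  // 7 stays 7
    assert(y == 7);
}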
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+5 -0)

@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                }
            });
        }
        else
        {
            static_for<0, acc_thread_buf.Size(), 1>{}(
                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
        }

        block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
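The added else branch simply visits every element of the statically sized accumulator buffer and applies the accumulator elementwise op in place. A small host-side analogue, with std::index_sequence standing in for ck::static_for (which unrolls the loop at compile time):

#include <array>
#include <cstddef>
#include <utility>

// Sketch of the new branch: apply op(dst, src) to every element of a static buffer,
// like acc_element_op(acc_thread_buf(i), acc_thread_buf[i]).
template <typename Buf, typename Op, std::size_t... Is>
void apply_elementwise(Buf& buf, Op op, std::index_sequence<Is...>)
{
    (op(buf[Is], buf[Is]), ...);
}

int main()
{
    std::array<float, 8> acc{-1.f, 2.f, -3.f, 4.f, -5.f, 6.f, -7.f, 8.f};
    auto relu = [](float& y, float x) { y = x > 0.f ? x : 0.f; };
    apply_elementwise(acc, relu, std::make_index_sequence<8>{});
}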
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp (new file, +230)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseElementwise2dFunctor,
          typename InGrid2dDescTuple,
          typename OutGrid2dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation>
__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple,
                                      const OutGrid2dDescTuple out_grid_2d_desc_tuple,
                                      const InDataTypePointerTuple p_in_global_tuple,
                                      const OutDataTypePointerTuple p_out_global_tuple,
                                      const ElementwiseOperation elementwise_op,
                                      const index_t num_threads_m,
                                      const index_t num_threads_n)
{
    GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
                                      out_grid_2d_desc_tuple,
                                      p_in_global_tuple,
                                      p_out_global_tuple,
                                      elementwise_op,
                                      num_threads_m,
                                      num_threads_n);
}

template <typename InGrid2dDescTuple,
          typename OutGrid2dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation,
          index_t MPerThread,
          index_t NPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
struct GridwiseElementwise_2D
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGrid2dDescTuple::Size() &&
                      NumOutput == OutGrid2dDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr auto thread_buffer_desc_mn =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
                               const OutGrid2dDescTuple out_grid_2d_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const ElementwiseOperation elementwise_op,
                               const index_t num_threads_m,
                               const index_t num_threads_n)
    {
        auto in_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread * NPerThread, true>{};
            },
            Number<NumInput>{});

        auto out_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread * NPerThread, true>{};
            },
            Number<NumOutput>{});

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
        const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);

        const index_t loop_step_m = num_threads_m * MPerThread;
        const index_t loop_step_n = num_threads_n * NPerThread;

        const index_t thread_1d_id = get_thread_global_1d_id();
        index_t tid_m              = thread_1d_id / num_threads_n;
        index_t tid_n              = thread_1d_id % num_threads_n;

        const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        DataType,
                                                        decltype(in_grid_2d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_mn),
                                                        Sequence<MPerThread, NPerThread>, // SliceLengths
                                                        Sequence<0, 1>,                   // DimAccessOrder
                                                        0,                                // SrcVectorDim
                                                        InScalarPerVectorSeq::At(I),      // ScalarPerVector
                                                        1,                                // SrcScalarStrideInVector
                                                        true>{in_grid_2d_desc_tuple[I], thread_global_offset};
            },
            Number<NumInput>{});

        auto out_global_store_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                                          DataType,
                                                          decltype(thread_buffer_desc_mn),
                                                          decltype(out_grid_2d_desc_tuple[I]),
                                                          PassThroughOp,
                                                          Sequence<MPerThread, NPerThread>, // SliceLengths
                                                          Sequence<0, 1>,                   // DimAccessOrder
                                                          1,                                // SrcVectorDim
                                                          1, // OutScalarPerVectorSeq::At(I)
                                                          InMemoryDataOperationEnum::Set,
                                                          1,
                                                          true>(
                    out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{});
            },
            Number<NumOutput>{});

        index_t num_iter_m = M / (loop_step_m);
        do
        {
            index_t num_iter_n = N / (loop_step_n);
            do
            {
                static_for<0, NumInput, 1>{}([&](auto I) {
                    in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
                                                in_global_buf_tuple[I],
                                                thread_buffer_desc_mn,
                                                make_tuple(I0, I0),
                                                in_thread_buf_tuple(I));

                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
                                                               make_multi_index(0, loop_step_n));
                });

                static_for<0, MPerThread, 1>{}([&](auto iM) {
                    static_for<0, NPerThread, 1>{}([&](auto iN) {
                        constexpr auto offset =
                            thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN));

                        // get reference to in data
                        const auto in_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto I) -> const auto& {
                                return in_thread_buf_tuple(I)(Number<offset>{});
                            },
                            Number<NumInput>{});

                        // get reference to dst data
                        auto out_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto I) -> auto& {
                                return out_thread_buf_tuple(I)(Number<offset>{});
                            },
                            Number<NumOutput>{});

                        unpack2(elementwise_op, out_data_refs, in_data_refs);
                    });
                });

                static_for<0, NumOutput, 1>{}([&](auto I) {
                    out_global_store_tuple(I).Run(thread_buffer_desc_mn,
                                                  make_tuple(I0, I0),
                                                  out_thread_buf_tuple[I],
                                                  out_grid_2d_desc_tuple[I],
                                                  out_global_buf_tuple(I));

                    out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I],
                                                                 make_multi_index(0, loop_step_n));
                });
            } while(--num_iter_n);

            static_for<0, NumInput, 1>{}([&](auto I) {
                in_global_load_tuple(I).MoveSrcSliceWindow(
                    in_grid_2d_desc_tuple[I],
                    make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
            });

            static_for<0, NumOutput, 1>{}([&](auto I) {
                out_global_store_tuple(I).MoveDstSliceWindow(
                    out_grid_2d_desc_tuple[I],
                    make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
            });
        } while(--num_iter_m);
    }
};

} // namespace ck
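The traversal above assigns each thread an MPerThread x NPerThread tile and strides the M/N index space by num_threads_m * MPerThread and num_threads_n * NPerThread until the whole grid is covered. A CPU-side sketch of that tiling (all sizes below are assumptions for illustration, and M/N are taken to be divisible by the loop steps, as the kernel itself assumes):

#include <cstdio>

int main()
{
    constexpr int MPerThread    = 2, NPerThread    = 4;
    constexpr int num_threads_m = 4, num_threads_n = 8;
    constexpr int M = 64, N = 128;

    const int loop_step_m = num_threads_m * MPerThread;
    const int loop_step_n = num_threads_n * NPerThread;

    const int thread_1d_id = 5; // stand-in for get_thread_global_1d_id()
    const int tid_m        = thread_1d_id / num_threads_n;
    const int tid_n        = thread_1d_id % num_threads_n;

    // equivalent to the do/while loops over num_iter_m / num_iter_n in the kernel
    for(int m0 = tid_m * MPerThread; m0 < M; m0 += loop_step_m)
        for(int n0 = tid_n * NPerThread; n0 < N; n0 += loop_step_n)
            std::printf("thread %d processes tile [%d..%d) x [%d..%d)\n",
                        thread_1d_id, m0, m0 + MPerThread, n0, n0 + NPerThread);
}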
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp (new file, +678)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename DsDataType,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1Value,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDlMultipleD_km_kn_mn
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    // ck::Tuple<const D0DataType*, const D1DataType*, ...>
    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                return static_cast<const DDataType*>(nullptr);
            },
            Number<NumDTensor>{});
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size =
            math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                                            const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                                                            const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
                K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
    {
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

        const auto M1 = Number<MPerBlock>{};
        const auto M0 = M / M1;

        const auto a_grid_desc_k0_m0_m1_k1 = transform_tensor_descriptor(
            a_grid_desc_k0_m_k1,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(M0, M1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_grid_desc_k0_m0_m1_k1;
    }

    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
    {
        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

        const auto N1 = Number<NPerBlock>{};
        const auto N0 = N / N1;

        const auto b_grid_desc_k0_n0_n1_k1 = transform_tensor_descriptor(
            b_grid_desc_k0_n_k1,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(N0, N1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_grid_desc_k0_n0_n1_k1;
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 =
            Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
                   M1PerThreadM111>{};
        constexpr auto N11 =
            Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
                   N1PerThreadN111>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }

    // Ds desc for source in blockwise copy
    template <typename DsGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N& ds_grid_desc_m_n)
    {
        return generate_tuple(
            [&](auto i) { return MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n[i]); },
            Number<NumDTensor>{});
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(c_grid_desc_m_n);
    }

    using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
    using DsGridPointer  = decltype(MakeDsGridPointer());

    template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
              bool HasMainKBlockLoop,
              bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               DsGridPointer p_ds_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const AElementwiseOperation&,
                               const BElementwiseOperation&,
                               const CDEElementwiseOperation& cde_element_op,
                               const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
                               const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
                               const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
                               const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
                               const Block2CTileMap& block_2_ctile_map,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force index data into SGPR
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }

        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy; be careful of LDS alignment
        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy; be careful of LDS alignment
        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                              a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                      "wrong!");

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v5r1<BlockSize,
                                              InMemoryDataOperationEnum::Set,
                                              Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
                                              ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                              ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
                                              decltype(a_block_desc_k0_m0_m1_k1),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<0, 1, 2, 3>,
                                              ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
                                              ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
                                              ABlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
                                              Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
                                              false,
                                              true>(a_grid_desc_k0_m0_m1_k1,
                                                    make_multi_index(0, im0, 0, 0),
                                                    a_block_desc_k0_m0_m1_k1,
                                                    make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v5r1<BlockSize,
                                              InMemoryDataOperationEnum::Set,
                                              Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
                                              BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                              BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
                                              decltype(b_block_desc_k0_n0_n1_k1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<0, 1, 2, 3>,
                                              BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
                                              BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
                                              BBlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
                                              Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
                                              false,
                                              true>(b_grid_desc_k0_n0_n1_k1,
                                                    make_multi_index(0, in0, 0, 0),
                                                    b_block_desc_k0_n0_n1_k1,
                                                    make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        const auto blockwise_gemm =
            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
                BlockSize,
                FloatAB,
                FloatAB,
                FloatAcc,
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                M1PerThreadM111,
                N1PerThreadN111,
                KPerThread,
                M11N11ThreadClusterM110Xs,
                M11N11ThreadClusterN110Xs,
                M1PerThreadM111,
                N1PerThreadN111>{};

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        // Initialize C
        c_thread_buf.Clear();

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

            block_sync_lds();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if has 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                           I1,
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                           Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(get_thread_local_1d_id());

            const auto ds_grid_buf = generate_tuple(
                [&](auto i) {
                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
                },
                Number<NumDTensor>{});

            auto ds_thread_buf = generate_tuple(
                [&](auto i) {
                    using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                    return StaticBuffer<AddressSpaceEnum::Vgpr,
                                        DDataType,
                                        c_m10_m11_n10_n11_thread_tensor_lengths[I3],
                                        true>{};
                },
                Number<NumDTensor>{});

            auto ds_threadwise_copy = generate_tuple(
                [&](auto i) {
                    using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                    return ThreadwiseTensorSliceTransfer_v2<
                        DDataType,
                        DDataType,
                        decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
                        decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                        Sequence<I1,
                                 I1,
                                 I1,
                                 I1,
                                 I1,
                                 Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}>,
                        CThreadTransferSrcDstAccessOrder,
                        CThreadTransferSrcDstVectorDim,
                        CThreadTransferDstScalarPerVector,
                        1,
                        false>(
                        ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                        make_multi_index(im0,
                                         c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                         c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                         in0,
                                         c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                         c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
                },
                Number<NumDTensor>{});

            static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) {
                static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) {
                    static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) {
                        // load d matrix data
                        static_for<0, NumDTensor, 1>{}([&](auto i) {
                            ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                                                      ds_grid_buf[i],
                                                      c_thread_desc_m0_m10_m11_n0_n10_n11,
                                                      make_tuple(I0, I0, I0, I0, I0, I0),
                                                      ds_thread_buf(i));
                        });

                        // cal element op
                        static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}([&](auto i) {
                            // get reference to src data
                            const auto src_data_refs = generate_tie(
                                // return type should be lvalue
                                [&](auto iSrc) -> const auto& { return ds_thread_buf[iSrc][i]; },
                                Number<NumDTensor>{});

                            // get reference to dst data
                            constexpr index_t c_offset =
                                c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
                                    make_tuple(0, m10, m11, 0, n10, i));

                            auto dst_data_refs = generate_tie(
                                // return type should be lvalue
                                [&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); },
                                Number<2>{});

                            unpack2(cde_element_op, dst_data_refs, src_data_refs);
                        });

                        static_for<0, NumDTensor, 1>{}([&](auto i) {
                            ds_threadwise_copy(i).MoveSrcSliceWindow(
                                ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                                make_multi_index(0, 0, 0, 0, 1, 0));
                        });
                    });

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        ds_threadwise_copy(i).MoveSrcSliceWindow(
                            ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                            make_multi_index(
                                0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));
                    });
                });

                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    ds_threadwise_copy(i).MoveSrcSliceWindow(
                        ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
                        make_multi_index(
                            0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
                });
            });

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{
                c_grid_desc_m0_m10_m11_n0_n10_n11,
                make_multi_index(im0,
                                 c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                 c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                 in0,
                                 c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                 c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                ck::tensor_operation::element_wise::PassThrough{}}
                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                     c_grid_buf);
        }
    }
};

} // namespace ck
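The LDS double-buffer pipeline's shape is decided entirely by the two predicates CalculateHasMainKBlockLoop and CalculateHasDoubleTailKBlockLoop: the main do/while runs only when at least three K tiles exist, and the tail does two GEMMs when the tile count is even, otherwise one. A host-side sketch of the same formulas (K0PerBlock below is an assumed value for illustration):

#include <cstdio>

int main()
{
    constexpr int K0PerBlock = 4;
    for(int K0 = K0PerBlock; K0 <= 8 * K0PerBlock; K0 += K0PerBlock)
    {
        const bool has_main_loop   = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; // >= 3 K tiles
        const bool has_double_tail = (K0 / K0PerBlock) % 2 == 0;               // even tile count
        std::printf("K0=%2d  main_loop=%d  double_tail=%d\n", K0, has_main_loop, has_double_tail);
    }
}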
include/ck/utility/amd_wmma.hpp (new file, +102)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP

#include "data_type.hpp"

// TODO: Add arch limitation
namespace ck {

// wave32 only
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};

// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15]  = result
        // true : D0.[16:31] = result
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
    }
};

// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15]  = result
        // true : D0.[16:31] = result
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
            reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            neg_a, bit_cast<int32x4_t>(reg_a), neg_b, bit_cast<int32x4_t>(reg_b),
            reg_c.template AsType<int32x8_t>()[Number<0>{}], clamp);
    }
};

} // namespace ck
#endif
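For orientation, a hedged device-side sketch of how such a wrapper is typically invoked: per wave32 wavefront, reg_a and reg_b each hold 16 fp16 values per lane, and the accumulator is assumed to be a ck::vector_type<float, 8> exposing the AsType<> view used inside Run (this sketch is illustrative, not part of the commit):

#include "ck/utility/amd_wmma.hpp" // include path is an assumption; adjust to the build setup

__device__ void wmma_tile_sketch(const ck::half16_t& reg_a,
                                 const ck::half16_t& reg_b,
                                 ck::vector_type<float, 8>& acc)
{
    // one 16x16x16 fp16 -> fp32 WMMA step accumulating into acc
    ck::intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(reg_a, reg_b, acc);
}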
include/ck/utility/math_v2.hpp (+16 -2)

@@ -114,7 +114,16 @@ static inline __device__ int4_t abs(int4_t x)
};
#endif

-static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
+static inline __device__ half_t abs(half_t x)
+{
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    uint16_t abs_xx = xx & 0x7fff;
+
+    half_t abs_x = ck::bit_cast<half_t>(abs_xx);
+
+    return abs_x;
+};

static inline __device__ bool isnan(float x) { return ::isnan(x); };

@@ -140,7 +149,12 @@ static inline __device__ bool isnan(int4_t x)
};
#endif

-static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
+static inline __device__ bool isnan(half_t x)
+{
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    return (xx & 0x7FFF) > 0x7C00;
+};

static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
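Both replacements drop the __habs/__hisnan builtins in favour of plain bit manipulation on the IEEE binary16 encoding: clearing bit 15 removes the sign, and a value is NaN exactly when the remaining 15 bits exceed 0x7C00 (all-ones exponent with a non-zero mantissa). A small host-side check of the same integer logic (pure illustration; no half type required):

#include <cassert>
#include <cstdint>

int main()
{
    auto abs_bits  = [](std::uint16_t x) -> std::uint16_t { return x & 0x7fff; };
    auto isnan_f16 = [](std::uint16_t x) { return (x & 0x7FFF) > 0x7C00; };

    assert(abs_bits(0xC000) == 0x4000); // -2.0h -> 2.0h
    assert(!isnan_f16(0x7C00));         // +inf is not NaN
    assert(isnan_f16(0x7E00));          // canonical quiet NaN
    assert(isnan_f16(0xFC01));          // negative signalling-NaN pattern
}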
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp (+44 -0)

@@ -131,6 +131,47 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
        PassThrough,
        PassThrough>>>& instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, GNHWC, GKYXC, Empty_Tuple, GNHWK,
                                                              F16, F16, Empty_Tuple, F16,
                                                              PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, GNHWC, GKYXC, Empty_Tuple, GNHWK,
                                                              F32, F32, Empty_Tuple, F32,
                                                              PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, GNHWC, GKYXC, Empty_Tuple, GNHWK,
                                                              int8_t, int8_t, Empty_Tuple, int8_t,
                                                              PassThrough, PassThrough, PassThrough>>>&
        instances);

// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,

@@ -273,11 +314,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&

@@ -289,6 +332,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
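The factory hunks follow a single pattern: for each data-type branch, the DL instance list is appended right after the XDL one, so callers of GetInstances() receive one merged vector covering both kernel families. A minimal stand-in of that registration pattern (names below are hypothetical placeholders, not the CK API):

#include <memory>
#include <vector>

struct Op { virtual ~Op() = default; };
using OpPtrs = std::vector<std::unique_ptr<Op>>;

void add_xdl_f16_instances(OpPtrs& v) { v.push_back(std::make_unique<Op>()); }
void add_dl_f16_instances(OpPtrs& v)  { v.push_back(std::make_unique<Op>()); }

OpPtrs get_f16_instances()
{
    OpPtrs op_ptrs;
    add_xdl_f16_instances(op_ptrs); // previously the only family registered
    add_dl_f16_instances(op_ptrs);  // the newly added DL kernels ride along
    return op_ptrs;
}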
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp (deleted, -116; was mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK, F16, F16, F16,
                                                     PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK, F32, F32, F32,
                                                     PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK, int8_t, int8_t, int8_t,
                                                     PassThrough, PassThrough, PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwd<NumDimSpatial,
                                                       InLayout,
                                                       WeiLayout,
                                                       OutLayout,
                                                       InDataType,
                                                       WeiDataType,
                                                       OutDataType,
                                                       ck::tensor_operation::element_wise::PassThrough,
                                                       ck::tensor_operation::element_wise::PassThrough,
                                                       ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial,
                                          InLayout,
                                          WeiLayout,
                                          OutLayout,
                                          InDataType,
                                          WeiDataType,
                                          OutDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp (+26 -1)

@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

 // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
 using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =

@@ -71,12 +72,36 @@ using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =
    // clang-format on
    >;

// irregular tile size
using device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
    // Columns: A/B/C/Acc data types | A/B/C layouts | A/B/C elementwise ops | GEMM specialization |
    //          BlockSize | MPerBlock NPerBlock K0PerBlock K1 | MPerXDL NPerXDL | MXdlPerWave NXdlPerWave |
    //          A block-transfer params | B block-transfer params | C thread-transfer params |
    //          NumPrefetch | LoopScheduler | Pipeline
    // pipeline v1, 1 wave
    DeviceGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding,
                  64, 16, 16, 4, 8, 16, 16, 1, 1,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    // pipeline v1, 2 waves
    ,
    DeviceGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding,
                  64, 16, 16, 4, 8, 16, 16, 1, 1,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
    // pipeline v2, 1 wave
    ,
    DeviceGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding,
                  64, 16, 16, 4, 8, 16, 16, 1, 1,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
    // clang-format on
    >;

void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16,
                                           PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
    add_device_operation_instances(instances,
                                   device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{});
}

} // namespace instance
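These "irregular tile size" instances pair a small 16x16 block tile with the MNPadding specialization, so GEMM shapes whose M or N is not a multiple of the regular tile sizes can still be dispatched; padding conceptually rounds each dimension up to the next tile boundary and masks the extra rows/columns on the C write-back. A tiny sketch of that rounding (the helper and sample shapes are illustrative only; tile sizes mirror the MPerBlock/NPerBlock of the instances above):

#include <cstdio>

constexpr int round_up(int x, int tile) { return ((x + tile - 1) / tile) * tile; }

int main()
{
    constexpr int MPerBlock = 16, NPerBlock = 16;
    const int M = 100, N = 37; // shapes a no-padding, large-tile instance could not cover exactly
    std::printf("padded GEMM extent: %d x %d -> %d x %d\n",
                M, N, round_up(M, MPerBlock), round_up(N, NPerBlock));
}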
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp (+26 -1)

@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

 // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
 using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =

@@ -71,12 +72,36 @@ using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =
    // clang-format on
    >;

// irregular tile size
using device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple<
    // clang-format off
    // Columns: same tuning-parameter layout as the km_kn_mn irregular-tile instances above
    // pipeline v1, 1 wave
    DeviceGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding,
                  64, 16, 16, 4, 8, 16, 16, 1, 1,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true,
                  7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    // pipeline v1, 2 waves
    ,
    DeviceGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding,
                  64, 16, 16, 4, 8, 16, 16, 1, 1,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true,
                  7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
    // pipeline v2, 1 wave
    ,
    DeviceGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding,
                  64, 16, 16, 4, 8, 16, 16, 1, 1,
                  S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true,
                  S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true,
                  7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
    // clang-format on
    >;

void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F16, F16, F16,
                                           PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
    add_device_operation_instances(instances,
                                   device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{});
}

} // namespace instance
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
a768dea5
...
...
@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
...
...
@@ -98,12 +99,36 @@ using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
        // clang-format on
        >;

// irregular tile size
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | |
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
    add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
@@ -94,17 +94,20 @@ using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
        DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
        , DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
...
...
@@ -30,7 +30,8 @@ using S = ck::Sequence<Is...>;
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0, d1)
// output: e[m, n]
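The comment above describes the fused epilogue these AddAddFastGelu instances implement: the GEMM result is combined with two auxiliary tensors and then passed through FastGelu. As a hedged illustration only (the constants follow the common tanh-based GELU approximation and may not match the library's element-wise functor bit-for-bit), the per-element computation is roughly:

// Illustrative per-element AddAddFastGelu epilogue (not the library's exact functor):
// e = FastGelu(c + d0 + d1), where c is the f32 accumulator of a * b.
#include <cmath>

inline float add_add_fastgelu_sketch(float c, float d0, float d1)
{
    const float x = c + d0 + d1;
    const float u = 0.7978845608f * (x + 0.044715f * x * x * x); // sqrt(2/pi) * (x + 0.044715 x^3)
    return 0.5f * x * (1.0f + std::tanh(u));                     // tanh approximation of GELU
}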
...
...
@@ -102,6 +103,29 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Row,
...
...
@@ -118,6 +142,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
...
...
@@ -30,7 +30,8 @@ using S = ck::Sequence<Is...>;
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0, d1)
// output: e[m, n]
...
...
@@ -102,6 +103,29 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Col,
...
...
@@ -118,6 +142,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
...
...
@@ -30,7 +30,8 @@ using S = ck::Sequence<Is...>;
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0, d1)
// output: e[m, n]
...
...
@@ -102,6 +103,29 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Row,
...
...
@@ -118,6 +142,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
...
...
@@ -30,7 +30,8 @@ using S = ck::Sequence<Is...>;
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0, d1)
// output: e[m, n]
...
...
@@ -93,6 +94,29 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Col,
...
...
@@ -109,6 +133,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
...
...
@@ -15,7 +15,8 @@ namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0)
// output: e[m, n]
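The gemm_add_fastgelu files register the single-D variant of the same epilogue: only one auxiliary tensor is added before the activation. A correspondingly hedged per-element sketch, with the same caveats as the AddAddFastGelu sketch earlier in this commit (approximate constants, not the library's exact functor):

// Illustrative only: e = FastGelu(c + d0)
#include <cmath>

inline float add_fastgelu_sketch(float c, float d0)
{
    const float x = c + d0;
    const float u = 0.7978845608f * (x + 0.044715f * x * x * x);
    return 0.5f * x * (1.0f + std::tanh(u));
}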
...
...
@@ -86,6 +87,29 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instanc
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Row,
...
...
@@ -101,6 +125,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_inst
{
    add_device_operation_instances(
        instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
...
...
@@ -15,7 +15,8 @@ namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0)
// output: e[m, n]
...
...
@@ -86,6 +87,29 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instanc
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Col,
...
...
@@ -101,6 +125,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_inst
{
    add_device_operation_instances(
        instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
...
...
@@ -15,7 +15,8 @@ namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0)
// output: e[m, n]
...
...
@@ -86,6 +87,29 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instanc
        // clang-format on
        >;

// irregular tile size
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // pipeline v1, 1 wave
        DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        , DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
        // pipeline v2, 1 wave
        , DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
        // clang-format on
        >;
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Row,
...
...
@@ -101,6 +125,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_inst
{
    add_device_operation_instances(
        instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances{});
}

} // namespace instance
...
...