gaoqiong / composable_kernel_ROCM / Commits

Commit b30d416c, authored Feb 10, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: 2fd6c6d4, 94fbaac0

Changes: 183 files in the commit; showing 20 changed files with 1634 additions and 104 deletions (+1634 -104).
include/ck/wrapper/layout.hpp  +3 -0
include/ck/wrapper/operations/copy.hpp  +94 -31
include/ck/wrapper/operations/gemm.hpp  +337 -0
include/ck/wrapper/tensor.hpp  +25 -14
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp  +69 -0
include/ck/wrapper/utils/tensor_partition.hpp  +225 -30
include/ck/wrapper/utils/tensor_utils.hpp  +25 -10
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  +3 -1
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp  +59 -0
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp  +59 -0
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp  +59 -0
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp  +26 -4
library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp  +114 -0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp  +53 -1
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp  +116 -0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp  +116 -0
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp  +121 -11
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp  +12 -2
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp  +0 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp  +118 -0
include/ck/wrapper/layout.hpp

@@ -248,6 +248,9 @@ struct Layout
    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

    public:
    using LayoutShape                  = Shape;
    using LayoutUnrolledDescriptorType = UnrolledDescriptorType;

    /**
     * \brief Transform descriptor to align to passed indexes.
     *
...
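Editor's note (not part of the diff): the two aliases added here are what the new GEMM wrapper code reads back from a tile layout. A minimal sketch of how downstream code can inspect them; `MyTileLayout` is a hypothetical alias for some concrete wrapper Layout type.

// Sketch only: assumes a wrapper Layout specialization named MyTileLayout exists.
using MyShape      = typename MyTileLayout::LayoutShape;
using MyDescriptor = typename MyTileLayout::LayoutUnrolledDescriptorType;
// A blockwise-gemm tile layout is expected to be 2D, e.g. (MPerBlock, KPerBlock).
static_assert(MyShape::Size() == 2, "expected a 2D (M, K) or (N, K) tile layout");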
include/ck/wrapper/operations/copy.hpp

@@ -3,45 +3,18 @@
#pragma once

#include "../utils/tensor_utils.hpp"
#include "ck/wrapper/utils/tensor_utils.hpp"

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Perform generic copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    if constexpr(!SrcTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(src_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else if constexpr(!DstTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(dst_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else
    {
        for(int i = 0; i < size(src_tensor); i++)
        {
            dst_tensor(i) = src_tensor(i);
        }
    }
}

/**
 * \brief Perform optimized copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
...

@@ -167,9 +140,99 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
    else
    {
        // Perform copy between StaticBuffers
        copy(src_tensor, dst_tensor);
        static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
}

/**
 * \brief Perform generic copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    // Generate default params
    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();
    // Incrementing dims 0, 1, 2 ... num_dims - 1
    constexpr auto dim_access_order_tuple =
        generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
    constexpr index_t vector_dim        = num_dims - 1;
    constexpr index_t scalar_per_vector = 1;
    copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
}

/**
 * \brief Perform optimized blockwise copy between two tensors. Tensors must have the
 * same size.
 *
 * \note Currently Vgpr and Sgpr are not supported.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 * \param thread_layout Thread layout per each dimension for copy.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType,
          typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
                               DstTensorType& dst_tensor,
                               [[maybe_unused]] ThreadLayoutTuple& thread_layout)
{
    static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto tile_lengths_seq =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto thread_layout_seq = generate_sequence_v2(
        [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;

    // Perform copy between DynamicBuffers
    auto transfer = ThreadGroupTensorSliceTransfer_v7<
        ThisThreadBlock,
        Tuple<typename SrcTensorType::TensorElementType>,
        Tuple<typename DstTensorType::TensorElementType>,
        decltype(tie(in_grid_desc)),
        decltype(tie(out_grid_desc)),
        tensor_operation::element_wise::PassThrough,
        Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
        std::remove_const_t<decltype(tile_lengths_seq)>,
        std::remove_const_t<decltype(thread_layout_seq)>,
        std::remove_const_t<decltype(dim_access_order)>,
        std::remove_const_t<decltype(dim_access_order)>,
        VectorDim,
        ScalarPerVector,
        Sequence<true>,
        Sequence<true>>{in_grid_desc,
                        make_tuple(src_tensor.GetMultiIdxOffsets()),
                        out_grid_desc,
                        make_tuple(dst_tensor.GetMultiIdxOffsets()),
                        tensor_operation::element_wise::PassThrough{}};

    transfer.Run(tie(in_grid_desc),
                 tie(src_tensor.GetBuffer()),
                 tie(out_grid_desc),
                 tie(dst_tensor.GetBuffer()));
}

} // namespace wrapper
} // namespace ck
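Usage sketch (not part of this commit): how the two entry points above are intended to be called from a kernel. The partition/tile/thread-layout names below are hypothetical placeholders for objects produced by make_local_partition/make_local_tile.

// Sketch only: src_partition/dst_partition are assumed to be wrapper Tensor partitions
// of identical shape created elsewhere in the kernel.
ck::wrapper::copy(src_partition, dst_partition); // generic threadwise copy, ScalarPerVector = 1

// Optimized blockwise copy over 2D tiles: vectorize the last dimension, 4 scalars per access.
constexpr auto dim_order = ck::make_tuple(ck::Number<0>{}, ck::Number<1>{});
ck::wrapper::blockwise_copy<decltype(dim_order), /*VectorDim=*/1, /*ScalarPerVector=*/4>(
    src_tile, dst_tile, thread_layout);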
include/ck/wrapper/operations/gemm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/tensor_utils.hpp"
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"

namespace ck {
namespace wrapper {

namespace {
namespace detail {
/**
 * \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
 *
 * \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension.
 * \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
 *
 * \return Block descriptor (K0, MPerBlock or NPerBlock, K1)
 */
template <index_t K1, typename TileLayout>
__device__ constexpr auto GetBlockDescriptor()
{
    using TileLayoutShape      = typename TileLayout::LayoutShape;
    using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;

    constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
    // MPerBlock or NPerBlock
    constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};

    constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
        TileLayoutDescriptor{},
        make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
                   make_pass_through_transform(Dim0)),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    return a_block_desc_k0_m_k1;
}
} // namespace detail
} // namespace

/**
 * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
 * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
 * data layout must be (NPerBlock, KPerBlock).
 *
 * \note C output Vgpr register layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
 *   dimension per tile.
 * - NXdlPerWave - The number of MFMA instructions run by single wave in N
 *   dimension per tile.
 * - MWave - Equals 1 since this is for a single wave.
 * - NWave - Equals 1 since this is for a single wave.
 * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - NumInputsBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - GroupSize - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 *
 * \tparam DataType Input data types.
 * \tparam BlockSize Number of threads in block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
 *        (MPerBlock, KPerBlock) layout.
 * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
 *        (NPerBlock, KPerBlock) layout.
 * \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
 */
template <typename DataType,
          index_t BlockSize,
          typename GemmTraits,
          typename ATensorType,
          typename BTensorType,
          typename CTensorType>
__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                   const BTensorType& b_local_tile_tensor,
                                   CTensorType& c_reg_tensor)
{
    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
    static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
    static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);

    constexpr bool is_integer = is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> ||
                                is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
    using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());

    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                        DataType,
                                                        DataType,
                                                        GemmAccDataType,
                                                        ABlockDesc_K0_M_K1_Type,
                                                        BBlockDesc_K0_N_K1_Type,
                                                        GemmTraits::MPerXDL,
                                                        GemmTraits::NPerXDL,
                                                        GemmTraits::MXdlPerWave,
                                                        GemmTraits::NXdlPerWave,
                                                        GemmTraits::K1>
        blockwise_gemm_xdl_op{};

    blockwise_gemm_xdl_op.Run(
        a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
}

/**
 * \brief Create local partition per thread for C tensor.
 *
 * \note C output global memory layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
 *   dimension.
 * - NXdlPerWave - The number of MFMA instructions run by single wave in N
 *   dimension.
 * - MWave - The number of waves in single tile M dimension per tile.
 * - NWave - The number of waves in single tile N dimension per tile.
 * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - NumInputsBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - GroupSize - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 *
 * \tparam DataType Input data types.
 * \tparam ATileLayout A tensor layout.
 * \tparam BTileLayout B tensor layout.
 * \tparam BlockSize Number of threads in block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 * \param c_local_tile_tensor C tensor in LDS memory for blockwise gemm
 *        (MPerBlock, NPerBlock) layout.
 *
 * \return Partition c tensor for blockwise gemm.
 */
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          index_t BlockSize,
          typename GemmTraits,
          typename CTensorType>
__host__ __device__ constexpr auto
make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr bool is_integer = is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> ||
                                is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
        BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

    // Calculate offset on grid
    const auto c_thread_mtx_on_block =
        BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
    const index_t m_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_grid =
        c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
        make_tuple(Sequence<0>{}));
    const auto m_thread_data_on_grid_idx =
        m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_grid));

    const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));
    const auto n_thread_data_on_grid_idx =
        n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_grid));

    // Create partition shape based on descriptor dims.
    const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
    const auto partition_desc  = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
            partition_shape, partition_desc);

    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
                                                        n_thread_data_on_grid_idx[I0],
                                                        m_thread_data_on_grid_idx[I1],
                                                        n_thread_data_on_grid_idx[I1],
                                                        m_thread_data_on_grid_idx[I2],
                                                        m_thread_data_on_grid_idx[I3],
                                                        m_thread_data_on_grid_idx[I4],
                                                        n_thread_data_on_grid_idx[I2]));
    return partition_tensor;
}

/**
 * \brief Create C Vgpr register tensor for blockwise gemm.
 *
 * \note C output Vgpr register layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
 *   dimension per tile.
 * - NXdlPerWave - The number of MFMA instructions run by single wave in N
 *   dimension per tile.
 * - MWave - Equals 1 since this is for a single wave.
 * - NWave - Equals 1 since this is for a single wave.
 * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - NumInputsBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - GroupSize - Mfma instruction internal layout (depends on the
 *   instruction size).
 * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
 *   instruction size).
 *
 * \tparam DataType Input data types.
 * \tparam ATileLayout A tensor layout.
 * \tparam BTileLayout B tensor layout.
 * \tparam BlockSize Number of threads in block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 *
 * \return Vgpr c tensor for blockwise gemm.
 */
template <typename DataType,
          typename ATileLayout,
          typename BTileLayout,
          index_t BlockSize,
          typename GemmTraits>
__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr bool is_integer = is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> ||
                                is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            DataType,
                                                            DataType,
                                                            GemmAccDataType,
                                                            ABlockDesc_K0_M_K1_Type,
                                                            BBlockDesc_K0_N_K1_Type,
                                                            GemmTraits::MPerXDL,
                                                            GemmTraits::NPerXDL,
                                                            GemmTraits::MXdlPerWave,
                                                            GemmTraits::NXdlPerWave,
                                                            GemmTraits::K1>;

    // Calculate descriptor, shape and layout
    constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
    const auto vgpr_shape    = make_tuple(vgpr_desc.GetLengths()[I0],
                                          vgpr_desc.GetLengths()[I1],
                                          vgpr_desc.GetLengths()[I2],
                                          vgpr_desc.GetLengths()[I3],
                                          vgpr_desc.GetLengths()[I4],
                                          vgpr_desc.GetLengths()[I5],
                                          vgpr_desc.GetLengths()[I6],
                                          vgpr_desc.GetLengths()[I7]);
    const auto vgpr_layout =
        Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(vgpr_shape, vgpr_desc);

    // Get vector type for Vgpr
    using BlockwiseGemmCThreadBufferType =
        remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
    using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;

    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout);
}

} // namespace wrapper
} // namespace ck
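Usage sketch (not part of this commit): how the three entry points above are meant to compose inside a kernel. The tile and tensor names are hypothetical; the LDS tiles and the global C tile would come from make_local_tile/make_local_partition, and fp16 element tensors are assumed.

// Sketch only, assuming a_lds_tile (MPerBlock, KPerBlock) and b_lds_tile (NPerBlock, KPerBlock)
// are ck::half_t wrapper tensors in LDS, and c_block_tile is the block's C tile in global memory.
using Traits = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1;
constexpr ck::index_t kBlockSize = 256;

using ALayout = ck::remove_cvref_t<decltype(ck::wrapper::layout(a_lds_tile))>;
using BLayout = ck::remove_cvref_t<decltype(ck::wrapper::layout(b_lds_tile))>;

// Accumulator registers laid out in the 8D scheme documented above.
auto c_vgpr =
    ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<ck::half_t, ALayout, BLayout, kBlockSize, Traits>();

ck::wrapper::blockwise_gemm_xdl<ck::half_t, kBlockSize, Traits>(a_lds_tile, b_lds_tile, c_vgpr);

// Per-thread view of the C tile for writing results back out.
auto c_thread_partition = ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<
    ck::half_t, ALayout, BLayout, kBlockSize, Traits>(c_block_tile);
ck::wrapper::copy(c_vgpr, c_thread_partition);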
include/ck/wrapper/tensor.hpp

@@ -10,8 +10,8 @@
namespace ck {
namespace wrapper {

namespace detail {
namespace {
namespace detail {
/**
 * \brief Check if Tuple contains Slice object
 *
...

@@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};

    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace
} // namespace detail
} // namespace

/**
 * \brief Tensor wrapper that performs static and dynamic buffer logic.
...

@@ -209,7 +209,10 @@ struct Tensor
    public:
    using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType = ElementType;                         // DataType
    using TensorElementType =
        std::conditional_t<is_scalar_type<ElementType>::value,
                           ElementType,
                           typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
...

@@ -280,7 +283,7 @@ struct Tensor
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
    __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
...

@@ -301,13 +304,13 @@ struct Tensor
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
    __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(Idxs... idxs) const
    __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }
...

@@ -319,7 +322,7 @@ struct Tensor
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
    __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
...

@@ -340,13 +343,13 @@ struct Tensor
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
    __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(Idxs... idxs)
    __host__ __device__ TensorElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }
...

@@ -366,7 +369,7 @@ struct Tensor
     *
     * \return Pointer.
     */
    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
    __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
...

@@ -395,10 +398,18 @@ struct Tensor
                      ElementType,
                      ElementSpaceSize,
                      true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType = StaticBuffer<BufferAddressSpace,
                                          ElementType,
                                          size(Shape{}),
                                          true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType = std::conditional_t<
        is_scalar_type<ElementType>::value,
        StaticBuffer<BufferAddressSpace,
                     ElementType,
                     size(Shape{}),
                     true /*InvalidElementUseNumericalZeroValue*/>,
        StaticBufferTupleOfVector<BufferAddressSpace,
                                  TensorElementType,
                                  size(Shape{}) /
                                      scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  true /*InvalidElementUseNumericalZeroValue*/>>;
    // If register use static buffer, else use dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
...
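Editor's note (not part of the diff): the net effect of this hunk is that a register Tensor may now hold vector element types, while indexing still yields the underlying scalar TensorElementType. A hedged sketch of the indexing interface only; reg_tensor is hypothetical.

// Sketch only: reg_tensor is assumed to be a 2D Vgpr tensor created with
// ck::wrapper::make_register_tensor elsewhere; register tensors are indexed with
// compile-time Number<> indices.
auto v = reg_tensor(ck::Number<0>{}, ck::Number<1>{});       // returns TensorElementType
reg_tensor(ck::Number<0>{}, ck::Number<2>{}) = v;            // non-const overload is writable
auto w = reg_tensor[ck::make_tuple(ck::Number<1>{}, ck::Number<0>{})]; // tuple-based form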
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Traits for blockwise gemm xdl.
 *
 * \tparam MPerXDLValue The MFMA instruction size in M dimension.
 * \tparam NPerXDLValue The MFMA instruction size in N dimension.
 * \tparam MXdlPerWaveValue The number of MFMA instructions run by single
 *         wave in M dimension.
 * \tparam NXdlPerWaveValue The number of MFMA instructions run by single
 *         wave in N dimension.
 * \tparam K1Value The number of K-dim elements that are packed together as
 *         a separate logical dimension. Usually aligns with vector load size.
 */
template <index_t MPerXDLValue,
          index_t NPerXDLValue,
          index_t MXdlPerWaveValue,
          index_t NXdlPerWaveValue,
          index_t K1Value>
struct BlockwisGemmXdlTraits
{
    static constexpr index_t MPerXDL     = MPerXDLValue;
    static constexpr index_t NPerXDL     = NPerXDLValue;
    static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
    static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
    static constexpr index_t K1          = K1Value;
};

// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> {};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4> {};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> {};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8> {};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8> {};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8> {};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16> {};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16> {};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16> {};

} // namespace wrapper
} // namespace ck
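Editor's note (not part of the diff): a quick compile-time reading of one preset, using only the members the traits struct itself defines. The products below are plain arithmetic on those members; the "per wave" interpretation follows the \tparam documentation above.

using T = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1;
static_assert(T::MPerXDL == 32 && T::NPerXDL == 32, "32x32 MFMA tile");
static_assert(T::MPerXDL * T::MXdlPerWave == 128, "C rows covered per wave in M");
static_assert(T::NPerXDL * T::NXdlPerWave == 64, "C columns covered per wave in N");
static_assert(T::K1 == 4, "K elements packed per vector access");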
include/ck/wrapper/utils/tensor_partition.hpp

@@ -6,6 +6,7 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"

#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
...

@@ -14,6 +15,8 @@ namespace wrapper {
namespace {
namespace detail {
/**
 * \brief Calculate shape for partition based on number of threads per each dim and
 * previous shape
...

@@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
            const auto slice_len =
                ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
            return slice_len;
        },
        Number<Tuple<Ls...>::Size()>{});
}

/**
 * \brief Apply projection.
 *
 * \param base_tuple Tuple to apply the projection to.
 * \param projection Projection to remove selected dims from partitioning.
 *        slice(X) to remove, where X is the dim size, Number<1>{} to keep.
 * \return Multi index after projection.
 */
template <typename MultiIndex, typename ProjectionTuple>
__host__ __device__ constexpr auto
ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
                [[maybe_unused]] const ProjectionTuple& projection)
{
    if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
    {
        return Tuple<>{};
    }
    else
    {
        auto base_tuple_after_projection = generate_tuple(
            [&](auto i) {
                const auto i_num = Number<i.value>{};
                static_assert(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
                              is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
                if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
                {
                    // When slice (to remove), then insert empty tuple (will be removed in next
                    // step).
                    return Tuple<>{};
                }
                else
                {
                    return base_tuple.At(i_num);
                }
            },
            Number<MultiIndex::Size()>{});
        // Remove empty tuples
        return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
    }
}

/**
 * \brief Calculate shape with dims from projection.
 *
 * \param shape Base tensor shape.
 * \param projection Projection to remove selected dims from partitioning.
 *        slice(X) to remove, where X is the dim size, Number<1>{} to keep.
 * \return Shape with dims from projection.
 */
template <typename... Ts, typename... Ps>
__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
                                                                const Tuple<Ps...>& projection)
{
    return generate_tuple(
        [&](auto i) {
            if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
            {
                return size<i>(projection).to_;
            }
            else
            {
                // number of shape elements in the current fragment of shape and projection
                // (method to calculate the shape idx)
                constexpr index_t shape_i =
                    detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
                                            TupleSlice<0, i>(Tuple<Ps...>{}))
                        .Size();
                return size<shape_i>(shape);
            }
        },
        Number<Tuple<Ps...>::Size()>{});
}

/**
 * \brief Calculate total number of blocks.
 *
 * \param shape Base tensor shape.
 * \param tile_shape Tile shape.
 * \param projection Projection is used to remove selected dims from
 *        partitioning. Use `slice(X)` to remove a dimension, where X is the dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Tuple with the number of blocks.
 */
template <typename... Ts, typename... Ls>
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
                                                     const Tuple<Ls...>& tile_shape)
                                                     const Tuple<Ls...>& tile_shape,
                                                     const Tuple<Ps...>& projection)
{
    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
    return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
                          Number<Tuple<Ls...>::Size()>{});
    auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
    return generate_tuple(
        [&](auto i) {
            return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
                                                 size<i>(tile_shape));
        },
        Number<Tuple<Ls...>::Size()>{});
}
...

@@ -69,6 +155,20 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
    return thread_idxs * partition_lengths_seq + old_offset_idxs;
}

/**
 * \brief Calculate default projection.
 *
 * \param tile_shape Tile shape.
 * \return Default projection (filled with Number<1>{}).
 */
template <typename TileShape>
__host__ __device__ constexpr auto
GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
{
    return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}

} // namespace detail
} // namespace
...

@@ -78,35 +178,45 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
 * \param tensor Tensor for partition.
 * \param thread_lengths Layout of threads (cannot be nested).
 * \param thread_id Thread index represented as an integer.
 * \param projection Projection is used to remove selected dims from
 *        partitioning. Use `slice(X)` to remove a dimension, where X is the dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple>
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     [[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
                     const index_t thread_id)
                     const index_t thread_id,
                     const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
    // Calculate new partition shape
    const auto& tensor_shape = shape(tensor);
    // Calculate projected thread lengths
    constexpr auto projected_thread_lengths =
        detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
    constexpr auto partition_shape =
        CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
        detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
    // Create Thread Cluster Descriptor
    constexpr auto partition_lengths_seq = generate_sequence_v2(
        [&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
    constexpr auto partition_shape_seq =
        generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
                             Number<decltype(partition_shape)::Size()>{});
    constexpr auto thread_lengths_seq =
        generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
                             Number<ThreadLengthsTuple::Size()>{});
    constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
    // Calculate thread idxs and offsets
    const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
    const auto offset_multi_idxs =
        CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
    // Apply projection on thread idxs to remove not needed idxs
    const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
    const auto offset_multi_idxs     = detail::CalculateOffsetMultiIdxs(
        projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
    // Create new layout and tensor
    auto& flatten_desc  = layout(tensor).GetUnrolledDescriptor();
    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
            partition_shape, flatten_desc);
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
            partition_shape, unrolled_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
    // Apply offsets
...

@@ -114,6 +224,24 @@ make_local_partition(TensorType& tensor,
    return partition_tensor;
}

/**
 * \brief Create local partition for thread (currently only packed partition
 * is supported).
 *
 * \param tensor Tensor for partition.
 * \param thread_lengths Layout of threads (cannot be nested).
 * \param thread_id Thread index represented as an integer.
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
                                                        const ThreadLengthsTuple& thread_lengths,
                                                        const index_t thread_id)
{
    const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
    return make_local_partition(tensor, thread_lengths, thread_id, projection);
}

/**
 * \brief Create local tile for thread block (currently only packed tile
 * is supported).
...

@@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor,
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
 * \param block_id Block index represented as an integer.
 * \param projection Projection to remove selected dims from partitioning.
 *        slice(X) to remove, where X is the dim size, Number<1>{} to keep.
 * \return Tile tensor.
 */
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const index_t block_id,
                                                   const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(BlockShapeTuple{}));

    constexpr bool is_default_projection =
        is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();

    if constexpr(BlockShapeTuple::Size() == I2)
    // TODO: Enable block_2_tile_map partitioning for non-default projection.
    if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
    {
        // Optimized version for 2d tile shape [MxK]
        const auto block_2_tile_map =
...

@@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
    {
        // Calculate offsets
        // Sequence with data to process per block
        constexpr auto tile_shape_seq = generate_sequence_v2(
            [](auto I) { return size(BlockShapeTuple{}.At(I)); }, Number<BlockShapeTuple::Size()>{});
        constexpr auto projected_tile_shape =
            detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
        using ProjectedTileShapeTuple = decltype(projected_tile_shape);
        constexpr auto projected_tile_shape_seq =
            generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
                                 Number<ProjectedTileShapeTuple::Size()>{});
        // Tuple with number of blocks
        const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
        constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
        const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
        const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
        const auto block_idxs = block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
        const auto offset_multi_idxs =
            CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
        const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
        const auto offset_multi_idxs    = detail::CalculateOffsetMultiIdxs(
            projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(
                tile_shape, aligned_desc);
            Layout<remove_reference_t<ProjectedTileShapeTuple>, decltype(aligned_desc)>(
                projected_tile_shape, aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
...

@@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
    }
}

/**
 * \brief Create local tile for thread block (currently only packed tile
 * is supported).
 *
 * \note Currently, to get the best performance, please use a 2d shape.
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
 * \param block_id Block index represented as an integer.
 * \return Tile tensor.
 */
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const index_t block_id)
{
    const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
    return make_local_tile(tensor, tile_shape, block_id, projection);
}

/**
 * \brief Pad tensor shapes to be adjusted to tile lengths.
 *
 * \param tensor Tensor to pad.
 * \param tile_lengths Tile lengths to align the tensor shape to.
 * \return Padded tensor.
 */
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
    const auto& tensor_shape = shape(tensor);
    using TensorShapeType    = remove_reference_t<decltype(tensor_shape)>;

    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    // Generate sequence with ones to mark that all dims will be padded
    constexpr auto do_pads_seq =
        generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
    // Create descriptor with padding
    auto padded_desc =
        tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
    // Generate padded shape
    const auto padded_shape = generate_tuple(
        [&](auto i) {
            const auto& dim         = size<i>(tensor_shape);
            const auto& tile_length = size<i>(tile_lengths);
            return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
        },
        Number<TileLengths::Size()>{});
    // Create layout and tensor
    const auto padded_layout =
        Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
    partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
    return partition_tensor;
}

} // namespace wrapper
} // namespace ck
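Usage sketch (not part of this commit): how padding, tiling, projection and partitioning compose for the A operand of a GEMM-style kernel. All variable names are hypothetical; only the wrapper helpers shown in this file and the slice() helper described in the projection docs above are assumed.

// Sketch only. a_global is assumed to be an (M, K) wrapper tensor in global memory,
// and N the GEMM's N extent (used only to size the projected-out dimension).
constexpr auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
// A has no N dimension, so the N component of the (M, N, K) tile shape is projected out.
const auto a_projection = ck::make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{});

// Pad (M, K) so they divide evenly by the tile lengths, then pick this block's (M, K) tile.
auto a_padded = ck::wrapper::pad(a_global, ck::make_tuple(ck::Number<128>{}, ck::Number<64>{}));
auto a_tile   = ck::wrapper::make_local_tile(a_padded, tile_shape, blockIdx.x, a_projection);

// Per-thread partition of that tile, distributing an 8x32 thread layout over (M, K).
auto a_thread_slice = ck::wrapper::make_local_partition(
    a_tile, ck::make_tuple(ck::Number<8>{}, ck::Number<32>{}), threadIdx.x);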
include/ck/wrapper/utils/tensor_utils.hpp

@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
...

@@ -19,9 +20,9 @@ namespace wrapper {
 * \brief Memory type, allowed members:
 * - Generic,
 * - Global,
 * - LDS,
 * - SGPR,
 * - VGPR,
 * - Lds,
 * - Sgpr,
 * - Vgpr,
 */
using MemoryTypeEnum = AddressSpaceEnum;
...

@@ -52,12 +53,8 @@ struct Slice
    __host__ __device__ constexpr auto range(const T& dim) const
    {
        if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
                     is_same_v<T, index_t>)
                     is_same_v<std::remove_const_t<T>, index_t>)
        {
            if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
            {
                throw std::runtime_error("Invalid range");
            }
            if(to_ < 0)
            {
                return dim - from_ + to_ + 1;
...

@@ -70,9 +67,10 @@ struct Slice
        }
        else
        {
            static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
            static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
                              (ToType{} < 0 || ToType{} > FromType{}),
                          "Invalid range");
            if constexpr(to_ < 0)
            if constexpr(ToType{} < 0)
            {
                return dim - from_ + to_ + Number<1>{};
            }
...

@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}

/**
 * \brief Clear tensor. (Only for Vgpr/Sgpr)
 *
 * \param tensor Tensor to be cleared.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
__host__ __device__ void
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
    static_assert(
        !Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
    return tensor.GetBuffer().Clear();
}

/**
 * \brief Get Tensor Layout.
 *
...
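Usage sketch (not part of this commit): the new clear() helper paired with make_register_tensor. The vgpr_layout variable is hypothetical and stands for a register-tensor Layout built elsewhere.

// Sketch only: vgpr_layout is assumed to be a wrapper Layout describing the accumulator tile.
auto acc = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, float>(vgpr_layout);
ck::wrapper::clear(acc); // zero the static (Vgpr) buffer before accumulation starts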
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...

@@ -98,6 +98,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear       = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu    = ck::tensor_operation::element_wise::AddFastGelu;
using AddRelu        = ck::tensor_operation::element_wise::AddRelu;
using AddSilu        = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd     = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu       = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply    = ck::tensor_operation::element_wise::AddMultiply;
...
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;
using F8  = f8_t;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_instances = std::tuple<
    // clang-format off
    //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
    //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
    //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
    //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // pipeline v1, 1 wave
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
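Editor's note (not part of the diff): these instance lists are compile-time tuples parameterized only by the GEMM specialization. A hedged sketch of materializing one; GemmSpecialization::Default is assumed to be a valid enumerator of the included gemm_specialization.hpp.

#include <tuple>
// Sketch only.
using FP8V1Instances = ck::tensor_operation::device::instance::
    device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_instances<
        ck::tensor_operation::device::GemmSpecialization::Default>;
// Each tuple element is one DeviceGemm_Xdl_CShuffle tuning variant of the fp8 row-row-row GEMM.
constexpr std::size_t kNumV1Variants = std::tuple_size_v<FP8V1Instances>;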
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp
→
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_
v1_interwave_
instance.hpp
View file @
b30d416c
This diff is collapsed.
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp
0 → 100644
View file @
b30d416c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;
using F8  = f8_t;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_instances = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v2, 1 wave
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>
// clang-format on
        >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
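For orientation (not part of this commit's diff): a tuple like device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_instances is normally consumed by a matching .cpp instance file, which instantiates it for one GemmSpecialization and appends the resulting instances to the caller-supplied vector through add_device_operation_instances. A minimal sketch of that registration is below, assuming the usual CK pattern and the headers included above; the concrete .cpp file and the choice of GemmSpecialization::Default for the "default" variant are assumptions.

// Hedged registration sketch only: the real .cpp counterpart is not shown in this diff.
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instantiate the tuple above for the unpadded (Default) specialization and register it.
    add_device_operation_instances(
        instances,
        device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_instances<GemmSpecialization::Default>{});
}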
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
b30d416c
...
...
@@ -345,11 +345,27 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
...
...
@@ -579,8 +595,14 @@ struct DeviceOperationInstanceFactory<
        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
...
...
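To show how the renamed v1 / v1_interwave / v2 instance sets reach an application, here is a minimal client-side sketch, assuming the standard CK factory workflow (DeviceOperationInstanceFactory::GetInstances, MakeArgumentPointer, IsSupportedArgument, MakeInvokerPointer). The function name and the buffer pointers, sizes, and strides are placeholders supplied by the caller, not something introduced by this commit.

// Hedged usage sketch: pick the first registered FP8 row/row/row instance that supports the problem.
// Assumes #include "ck/library/tensor_operation_instance/gpu/gemm.hpp" and device buffers of matching size.
void run_first_supported_fp8_gemm(const void* p_a, const void* p_b, void* p_c,
                                  ck::index_t M, ck::index_t N, ck::index_t K,
                                  ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using DeviceOp = ck::tensor_operation::device::DeviceGemm<Row, Row, Row, F8, F8, F8,
                                                              PassThrough, PassThrough, PassThrough>;

    // Collects the default + padded instances registered by the factory specialization above.
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op : op_ptrs)
    {
        auto argument = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                                PassThrough{}, PassThrough{}, PassThrough{});
        if(op->IsSupportedArgument(argument.get()))
        {
            op->MakeInvokerPointer()->Run(argument.get(), StreamConfig{});
            return;
        }
    }
}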
library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp
0 → 100644
View file @
b30d416c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, F16, I8, F16_Tuple, F16, PassThrough, PassThrough, Add>>>&);

void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, BF16, I8, BF16_Tuple, BF16, PassThrough, PassThrough, Add>>>&);

// GEMM + Add
template <typename ALayout, typename BLayout, typename D0Layout, typename ELayout,
          typename ADataType, typename BDataType, typename D0DataType, typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
    ALayout, BLayout, ck::Tuple<D0Layout>, ELayout, ADataType, BDataType, ck::Tuple<D0DataType>,
    EDataType, PassThrough, PassThrough, Add>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, ck::Tuple<D0Layout>, ELayout,
                                         ADataType, BDataType, ck::Tuple<D0DataType>, EDataType,
                                         PassThrough, PassThrough, Add>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16)
        if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
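For readers unfamiliar with the DeviceGemmMultipleD naming: with a single D tensor and the Add epilogue, the instances collected above compute E = (A * B) + D0 elementwise. A conceptual sketch of such a CDE elementwise functor follows; the shipped ck::tensor_operation::element_wise::Add lives elsewhere in CK and may differ in detail (type-specific overloads, conversions), so treat this as an illustration only.

// Conceptual sketch, not the shipped operator: combine the GEMM result c with the
// auxiliary tensor value d0 into the output e, which is what the factory above wires up.
struct AddSketch
{
    template <typename E, typename C, typename D>
    constexpr void operator()(E& e, const C& c, const D& d0) const
    {
        e = static_cast<E>(c + static_cast<C>(d0)); // e = c + d0, with explicit conversions
    }
};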
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp
View file @
b30d416c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -68,6 +68,32 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_inst
    PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, F16, I8, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, BF16, I8, BF16_Tuple, BF16, PassThrough, PassThrough, AddFastGelu>>>&);
// GEMM + Add + FastGelu
template <typename ALayout,
          typename BLayout,
...
...
@@ -106,6 +132,32 @@ struct DeviceOperationInstanceFactory<
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, bhalf_t> && is_same_v<EDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp
0 → 100644
View file @
b30d416c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, F16, I8, F16_Tuple, F16, PassThrough, PassThrough, AddRelu>>>&);

void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, BF16, I8, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu>>>&);

// GEMM + Add + Relu
template <typename ALayout, typename BLayout, typename D0Layout, typename ELayout,
          typename ADataType, typename BDataType, typename D0DataType, typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
    ALayout, BLayout, ck::Tuple<D0Layout>, ELayout, ADataType, BDataType, ck::Tuple<D0DataType>,
    EDataType, PassThrough, PassThrough, AddRelu>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, ck::Tuple<D0Layout>, ELayout,
                                         ADataType, BDataType, ck::Tuple<D0DataType>, EDataType,
                                         PassThrough, PassThrough, AddRelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16)
        if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp
0 → 100644
View file @
b30d416c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, F16, I8, F16_Tuple, F16, PassThrough, PassThrough, AddSilu>>>&);

void add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, BF16, I8, BF16_Tuple, BF16, PassThrough, PassThrough, AddSilu>>>&);

// GEMM + Add + Silu
template <typename ALayout, typename BLayout, typename D0Layout, typename ELayout,
          typename ADataType, typename BDataType, typename D0DataType, typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
    ALayout, BLayout, ck::Tuple<D0Layout>, ELayout, ADataType, BDataType, ck::Tuple<D0DataType>,
    EDataType, PassThrough, PassThrough, AddSilu>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, ck::Tuple<D0Layout>, ELayout,
                                         ADataType, BDataType, ck::Tuple<D0DataType>, EDataType,
                                         PassThrough, PassThrough, AddSilu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16)
        if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @
b30d416c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -27,12 +27,67 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
    DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
...
...
@@ -69,7 +124,17 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances(
    DeviceGemmSplitK<Col, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_interwave_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v2_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
...
...
@@ -89,12 +154,37 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(
    DeviceGemmSplitK<Col, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_interwave_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v2_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);
...
...
@@ -186,12 +276,25 @@ struct DeviceOperationInstanceFactory<
        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances(op_ptrs);
            add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
...
...
@@ -212,7 +315,9 @@ struct DeviceOperationInstanceFactory<
        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_interwave_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v2_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
...
...
@@ -236,12 +341,17 @@ struct DeviceOperationInstanceFactory<
        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_interwave_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v2_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
...
...
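One practical difference from the plain DeviceGemm factory earlier in this commit: split-K instances also take a k-batch factor that controls how the K dimension is partitioned and reduced. Below is a hedged sketch of driving one instance returned by this factory; the exact MakeArgumentPointer parameter list (in particular the trailing k_batch argument) follows the DeviceGemmSplitK base interface as understood at the time of this commit and should be treated as an assumption, and the function name and all pointers, sizes, and strides are placeholders.

// Hedged sketch: run one split-K instance with an explicit k_batch; error handling omitted.
using SplitKOp = ck::tensor_operation::device::DeviceGemmSplitK<Row, Row, Row, F16, F16, F16,
                                                                PassThrough, PassThrough, PassThrough>;

void run_splitk(SplitKOp& op, const void* p_a, const void* p_b, void* p_c,
                ck::index_t M, ck::index_t N, ck::index_t K,
                ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                ck::index_t k_batch)
{
    auto argument = op.MakeArgumentPointer(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{}, k_batch);
    if(op.IsSupportedArgument(argument.get()))
    {
        op.MakeInvokerPointer()->Run(argument.get(), StreamConfig{});
    }
}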
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp
View file @
b30d416c
...
...
@@ -83,12 +83,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp → library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp
View file @
b30d416c
File moved
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp
0 → 100644
View file @
b30d416c
This diff is collapsed.