gaoqiong / composable_kernel_ROCM / Commits

Commit 54d73870 authored Mar 22, 2024 by Jakub Piasecki

    added bf16@int8 version

parent f76c0072

Showing 6 changed files with 263 additions and 158 deletions (+263 −158)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp   +75 −62
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp   +26 −2
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt   +1 −0
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp   +85 −0
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp   +22 −37
profiler/src/profile_grouped_gemm_two_stage.cpp   +54 −57
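For orientation before the per-file diffs: the commit registers a new bf16 (A) x int8 (B) -> bf16 (E) grouped-GEMM instance list and wires it into the instance factory. A minimal sketch of how a client would reach those instances, assuming the standard CK DeviceOperationInstanceFactory usage pattern (the exact include set may vary):

    // Sketch only -- assumes the standard CK instance-factory usage pattern.
    #include "ck/ck.hpp"
    #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"

    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Empty_Tuple = ck::Tuple<>;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    int main()
    {
        // The operation signature this commit adds instances for:
        // A (bf16, row-major) * B (int8, row-major) = E (bf16, row-major), no D tensors.
        using DeviceOp = ck::tensor_operation::device::
            DeviceGroupedGemm<Row, Row, Empty_Tuple, Row,
                              ck::bhalf_t, int8_t, Empty_Tuple, ck::bhalf_t,
                              PassThrough, PassThrough, PassThrough>;

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        // After this commit the factory should return the two-stage instances
        // registered in the new bf16_i8_bf16 instance file.
        return op_ptrs.empty() ? 1 : 0;
    }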
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
@@ -19,14 +19,13 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
 #include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
-#include <ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp>

 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -103,7 +102,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
     // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1
     static constexpr index_t K0PerBlock = KPerBlock / AK1;

     using PassThrough       = ck::tensor_operation::element_wise::PassThrough;
     using WorkspaceDataType = float;

     // First stage GridwiseGEMM kernel.
...
@@ -153,10 +152,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
                    PipelineVer,
                    ComputeDataType>;

-    // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    // indices 1,3 -> MPerBlock, NPerBlock || divided by MPerBlock -> NPerThread
     template <typename ELay>
     static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
     {
         const auto c_grid_desc_m_n = [&]() {
...
@@ -192,7 +188,7 @@ template <typename ELay>
         }
     }

     static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                                          const std::array<index_t, NumDTensor>& NRaws,
                                          const std::array<index_t, NumDTensor>& DsStride)
     {
...
@@ -219,33 +215,45 @@ template <typename ELay>
     static constexpr auto MakeElementwiseInputSequence()
     {
         return generate_sequence_v2(
             [&]([[maybe_unused]] auto i) constexpr {
                 return Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
             },
             Number<NumDTensor + 1>{});
     }

     using CGridDesc_M_N  = typename GridwiseGemm::CGridDesc_M_N;
     using EGridDesc_M_N  = typename GridwiseGemm::CGridDesc_M_N;
     using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
     using DsGridPointer  = decltype(MakeDsGridPointer());
     using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
     using CDDataTypes    = decltype(concat_tuple(ck::Tuple<WorkspaceDataType*>{}, DsGridPointer{}));
     using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence());

     static constexpr index_t ClusterLengthMPerBlock =
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
     static constexpr index_t ClusterLengthNPerBlock =
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);

-    using GridwiseElementwise = GridwiseElementwise_2D<CDGridDesc_M_N,
-                                                       ck::Tuple<EGridDesc_M_N>,
-                                                       CDDataTypes,
-                                                       ck::Tuple<EDataType*>,
-                                                       CDEElementwiseOperation,
-                                                       MPerBlock / ClusterLengthMPerBlock,
-                                                       NPerBlock / ClusterLengthNPerBlock,
-                                                       ElementwiseInputSequence,
-                                                       ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>>;
-
-    using Block2ETileMapKSplit =
-        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+    using Block2ETileMapKSplit =
+        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+    using GridwiseElementwise = GridwiseElementwise<CDGridDesc_M_N,
+                                                    ck::Tuple<EGridDesc_M_N>,
+                                                    CDDataTypes,
+                                                    ck::Tuple<EDataType*>,
+                                                    Block2ETileMapKSplit,
+                                                    CDEElementwiseOperation,
+                                                    BlockSize,
+                                                    MPerBlock,
+                                                    NPerBlock,
+                                                    MPerBlock / ClusterLengthMPerBlock,
+                                                    NPerBlock / ClusterLengthNPerBlock,
+                                                    Sequence<0, 1>,
+                                                    ElementwiseInputSequence,
+                                                    ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>,
+                                                    true>;

     // Block2CTileMap configuration parameter.
     static constexpr index_t B2E_M01 = 8;
     using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
...
@@ -318,7 +326,6 @@ template <typename ELay>
             if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
                  group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
-                 // group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                  group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
             {
                 throw std::runtime_error("Error! group_count_ != p_As/Bs/Ds/Es size");
...
@@ -451,7 +458,7 @@ template <typename ELay>
                 auto grouped_block_2_ctile_map =
                     GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                 group_grid_size_[i] = grid_size_grp;
                 karg.KPadded  = k_padded;
                 karg.K0Padded = k0_padded;
                 karg.k_batch  = K_BATCH;
...
@@ -460,14 +467,15 @@ template <typename ELay>
                 gemm_kernel_args_[i].block_end_ = block_end;

 #if DEBUG_LOG
                 index_t tiles = (block_end - block_start) / K_BATCH;
                 std::cout << "block_start: " << block_start << "\n"
                           << "block_end: " << block_end << "\n"
                           << "tiles: " << tiles << std::endl << std::endl;
                 std::cout << "KPadded: " << karg.KPadded << std::endl
                           << "K0Padded: " << karg.K0Padded << std::endl
                           << "KBatch: " << karg.k_batch << std::endl
                           << "grid_size_: " << karg.KPadded << std::endl;
 #endif
             }
...
@@ -476,16 +484,13 @@ template <typename ELay>
         void UpdateEPointers()
         {
             // set-up each group E pointer to it's designated workspace memory.
-            float* p_workspace = reinterpret_cast<float*>(p_workspace_);
+            WorkspaceDataType* p_workspace = reinterpret_cast<WorkspaceDataType*>(p_workspace_);
             std::size_t offset = 0;

+            // TODO: per group e-ptr memory alignment (128B)?
             for(auto& arg : gemm_kernel_args_)
             {
                 arg.karg_.p_c_grid = p_workspace + offset;
                 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
-                // TODO: what about padding and the in-memory layout??
-                // Is there some descriptor for it?
                 offset += tiles * MPerBlock * NPerBlock;
 #if DEBUG_LOG
                 std::cout << "block_start: " << arg.block_start_ << "\n"
...
@@ -499,7 +504,7 @@ template <typename ELay>
         std::size_t GetWorkspaceSizeBytes() const
         {
             std::size_t size_bytes{0};

+            // TODO: per group e-ptr memory alignment (128B)?
             for(const auto& arg : gemm_kernel_args_)
             {
                 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
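The workspace bookkeeping in UpdateEPointers and GetWorkspaceSizeBytes above is plain arithmetic: each group's first-stage partial results occupy tiles * MPerBlock * NPerBlock workspace elements, where tiles = (block_end - block_start) / k_batch, because the first stage assigns k_batch workgroups to every output tile. A self-contained sketch of the same partitioning logic (all names here are illustrative, not CK API):

    // Standalone sketch of the per-group workspace partitioning shown above.
    #include <cstddef>
    #include <vector>

    struct GroupArg
    {
        int block_start;
        int block_end;
        int k_batch;
    };

    // A group owning (block_end - block_start) workgroups across k_batch
    // split-K slices covers (block_end - block_start) / k_batch output tiles,
    // each tile being MPerBlock x NPerBlock workspace elements.
    std::size_t workspace_elements(const std::vector<GroupArg>& groups,
                                   int MPerBlock,
                                   int NPerBlock)
    {
        std::size_t offset = 0;
        for(const auto& g : groups)
        {
            const int tiles = (g.block_end - g.block_start) / g.k_batch;
            // e.g. blocks [0, 32) with k_batch = 4 -> 8 tiles;
            // with 128x128 tiles that is 8 * 128 * 128 elements.
            offset += static_cast<std::size_t>(tiles) * MPerBlock * NPerBlock;
        }
        return offset; // multiply by sizeof(float) to get bytes
    }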
...
@@ -534,10 +539,8 @@ template <typename ELay>
         std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
         std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
         std::vector<DsGridPointer> ds_grid_pointer_;
         std::vector<void*> e_ptrs_;
     };

     // Invoker
...
@@ -729,13 +732,19 @@ template <typename ELay>
                                                  BElementwiseOperation,
                                                  PassThrough>;

-            const auto elementwise_kernel = kernel_elementwise_2d<GridwiseElementwise,
-                                                                  CDGridDesc_M_N,
-                                                                  ck::Tuple<EGridDesc_M_N>,
-                                                                  CDDataTypes,
-                                                                  ck::Tuple<EDataType*>,
-                                                                  CDEElementwiseOperation>;
+            const auto elementwise_kernel = kernel_elementwise<GridwiseElementwise,
+                                                               CDGridDesc_M_N,
+                                                               ck::Tuple<EGridDesc_M_N>,
+                                                               CDDataTypes,
+                                                               ck::Tuple<EDataType*>,
+                                                               Block2ETileMapKSplit,
+                                                               CDEElementwiseOperation>;

             return LaunchKernel(gemm_kernel, elementwise_kernel, arg, dev_gemm_args,
                                 dev_gemm_workspace, stream_config);
         }

         template <typename KernelFunction, typename KernelFunction2>
...
@@ -767,20 +776,24 @@ template <typename ELay>
                                            arg.b_element_op_,
                                            PassThrough{});

-            // launch elementwise kernels.
+            // Elementwise kernels
             for(int i = 0; i < arg.group_count_; ++i)
             {
                 time += launch_and_time_kernel(
                     stream_config,
                     elementwise_kernel,
-                    dim3(arg.group_grid_size_[i]), // probably group_grid_size <<< change it to group_grid_size[i]
+                    dim3(arg.group_grid_size_[i]),
                     dim3(BlockSize),
                     0,
                     concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
                                  arg.elementwise_d_grid_descs_m_n_[i]),
                     make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
                     concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
                                  arg.ds_grid_pointer_[i]),
                     type_convert<EDataType*>(arg.e_ptrs_[i]),
-                    arg.cde_element_op_,
-                    ClusterLengthMPerBlock,  // num_threads_m
-                    ClusterLengthNPerBlock); // num_threads_n
+                    Block2ETileMapKSplit{
+                        arg.elementwise_c_grid_descs_m_n_[i], B2E_M01, arg.K_BATCH},
+                    arg.cde_element_op_);
             }

             return time;
         }
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -159,6 +159,19 @@ void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_insta
                                                   PassThrough,
                                                   PassThrough,
                                                   PassThrough>>>& instances);

+void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  BF16,
+                                                  I8,
+                                                  Empty_Tuple,
+                                                  BF16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
+
 template <typename ALayout,
           typename BLayout,
           typename ELayout,
...
@@ -203,7 +216,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                 add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                 add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
                     op_ptrs);
                 add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(
                     op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                               is_same_v<ELayout, Row>)
...
@@ -242,6 +256,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                 add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
                     op_ptrs);
             }
         }
+        else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
+                          is_same_v<EDataType, bhalf_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
+                    op_ptrs);
+            }
+        }

         return op_ptrs;
     }
 };
...
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt

@@ -10,4 +10,5 @@ add_instance_library(device_grouped_gemm_instance
   device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
   device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
   device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
+  device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
 )
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using I8   = int8_t;
using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future
// a[m, k] * b[k, n] = e[m, n]
using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
    // clang-format off
//#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>
    // clang-format on
    >;
void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Row,
                                                  Empty_Tuple,
                                                  Row,
                                                  BF16,
                                                  I8,
                                                  Empty_Tuple,
                                                  BF16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp

This diff is collapsed.
profiler/src/profile_grouped_gemm_two_stage.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <numeric>
...
@@ -12,17 +12,12 @@
 enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
-    MK_NK_MN, // 1
 };

 enum struct GemmDataType
 {
-    F32_F32_F32,    // 0
-    F16_F16_F16,    // 1
-    BF16_BF16_BF16, // 2
-    INT8_INT8_INT8, // 3
-    F8_F16_F16,     // 4
-    F16_F8_F16,     // 5
+    F16_F16_F16,   // 0
+    BF16_INT8_BF16 // 1
 };

 #define OP_NAME "grouped_gemm_two_stage"
...
@@ -52,9 +47,8 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
         std::cout
             << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
-            << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8@fp6; 5: f16@f8)\n"
-            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
-            << "                     1: A[m, k] * B[n, k] = C[m, n];\n"
+            << "arg2: data type (0: fp16; 1: bf16@int8)\n"
+            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
             << "arg4: verification (0: no; 1: yes)\n"
             << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
             << "arg6: print tensor value (0: no; 1: yes)\n"
...
@@ -81,16 +75,17 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
     const auto Ns = argToIntArray(argv[9]);
     const auto Ks = argToIntArray(argv[10]);

+    // a: mk b: kn, c: mn: stride a =
     auto StrideAs = argToIntArray(argv[11]);
     auto StrideBs = argToIntArray(argv[12]);
     auto StrideCs = argToIntArray(argv[13]);

     const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;

     const int DefaultStrideA = Ks[0];
     const int DefaultStrideB = Ns[0];
     const int DefaultStrideC = Ns[0];

     for(size_t i = 0; i < Ms.size(); ++i)
     {
         StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
         StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
         StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
...
@@ -108,46 +103,48 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
                                                           ck::half_t,
                                                           ck::half_t,
                                                           float,
                                                           ck::tensor_layout::gemm::RowMajor,
                                                           ck::tensor_layout::gemm::RowMajor,
-                                                          ck::tensor_layout::gemm::RowMajor>(do_verification,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
             init_method,
             do_log,
             time_kernel,
             Ms,
             Ns,
             Ks,
             StrideAs,
             StrideBs,
             StrideCs,
             kbatch,
             n_warmup,
             n_iter);
     }
-    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
-                                                          ck::half_t,
-                                                          ck::half_t,
+        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
+                                                          int8_t,
+                                                          ck::bhalf_t,
                                                           float,
                                                           ck::tensor_layout::gemm::RowMajor,
-                                                          ck::tensor_layout::gemm::ColumnMajor,
-                                                          ck::tensor_layout::gemm::RowMajor>(do_verification,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
             init_method,
             do_log,
             time_kernel,
             Ms,
             Ns,
             Ks,
             StrideAs,
             StrideBs,
             StrideCs,
             kbatch,
             n_warmup,
             n_iter);
     }
     else
     {
...