gaoqiong / composable_kernel · Commits · cdec576c

Commit cdec576c, authored Aug 10, 2023 by Adam Osewski
Remove old splitk impl and replace it with tile looping one.
parent 50540084

Showing 2 changed files with 348 additions and 971 deletions:
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp            +348  -203
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp   +0    -768
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
(view file @ cdec576c)

@@ -5,11 +5,13 @@
 #include <iostream>
 #include <sstream>
+#include <tuple>

 #include "ck/ck.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
+#include "ck/host_utility/stream_utility.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/tuple.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -23,8 +25,28 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+///
+/// @brief Entry point kernel for device-wide Grouped GEMM operation.
+///
+/// @param[in]  gemm_descs_const  The pointer to the array of GEMM descriptor structures.
+/// @param[in]  tile_count        The overall number of output tiles we divided all groups into.
+/// @param[in]  k_batch           The number of batches we split the K dimension into.
+///
+/// @tparam     GridwiseGemm      The specific GridwiseGEMM algorithm implementation.
+/// @tparam     GemmDesc          The structure holding all necessary descriptors and other data
+///                               needed for grouped gemm calculation and work distribution.
+/// @tparam     HasMainKBlockLoop Flag indicating whether all GEMM problem configurations need to
+///                               loop over tiles in the K dimension.
+/// @tparam     CGlobalMemoryDataOperation  The functor used to store data in the output C matrix,
+///                               for example AtomicAdd or Store.
+///
 template <typename GridwiseGemm,
           typename GemmDesc,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation>
 __global__ void
@@ -32,44 +54,99 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       const index_t group_count)
+        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                       const index_t tile_count,
+                                       const index_t k_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

-    const index_t block_id = get_block_1d_id();
+    index_t tile_id         = get_block_1d_id();
+    const index_t grid_size = get_grid_size();
     const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
         cast_pointer_to_generic_address_space(gemm_descs_const));

-    index_t left     = 0;
-    index_t right    = group_count;
-    index_t group_id = index_t((left + right) / 2);
-    while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
-             block_id < gemm_desc_ptr[group_id].block_end_)) &&
-          left <= right)
-    {
-        if(block_id < gemm_desc_ptr[group_id].block_start_)
-        {
-            right = group_id;
-        }
-        else
-        {
-            left = group_id;
-        }
-        group_id = index_t((left + right) / 2);
-    }
-
-    LocalBlockToCTileMap<typename GemmDesc::B2CType> local_b2c{
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        block_id - gemm_desc_ptr[group_id].block_start_};
-
-    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        gemm_desc_ptr[group_id].karg_, static_cast<void*>(p_shared), local_b2c);
+    static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
+    static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
+    static constexpr index_t B2E_M01   = 8;
+    using CGridDesc_M_N        = typename GridwiseGemm::CGridDesc_M_N;
+    using Block2ETileMapKSplit =
+        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+
+    index_t group_id = 0;
+    index_t offset   = 0;
+    auto M           = gemm_desc_ptr[group_id].M;
+    auto N           = gemm_desc_ptr[group_id].N;
+    auto StrideC     = gemm_desc_ptr[group_id].StrideC;
+
+    auto c_grid_desc_m_n  = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+    auto b2c_tile_map     = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+    index_t grid_size_grp = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
+
+    index_t gemm_tile_id_start = 0;
+    index_t gemm_tile_id_end   = grid_size_grp;
+
+    while(tile_id < tile_count)
+    {
+        // Find the corresponding GEMM group for the output tile
+        while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end))
+        {
+            offset += grid_size_grp;
+            group_id++;
+
+            M       = gemm_desc_ptr[group_id].M;
+            N       = gemm_desc_ptr[group_id].N;
+            StrideC = gemm_desc_ptr[group_id].StrideC;
+
+            c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+            b2c_tile_map    = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+            grid_size_grp   = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
+
+            gemm_tile_id_start = offset;
+            gemm_tile_id_end   = offset + grid_size_grp;
+        }
+
+        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
+        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
+        const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
+
+        const auto K       = gemm_desc_ptr[group_id].K;
+        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
+        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
+        const auto MPadded = GridwiseGemm::CalculateMPadded(M);
+        const auto NPadded = GridwiseGemm::CalculateNPadded(N);
+        const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
+        const auto K0      = GridwiseGemm::CalculateK0(K, k_batch);
+
+        LocalBlockToCTileMap<Block2ETileMapKSplit> local_b2c{b2c_tile_map, tile_id - offset};
+
+        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
+            p_a_grid, p_b_grid, p_c_grid, M, N, K, StrideA, StrideB, StrideC,
+            MPadded, NPadded, KPadded, K0, k_batch, static_cast<void*>(p_shared), local_b2c);
+
+        tile_id += grid_size;
+    }
 #else
     ignore = gemm_descs_const;
-    ignore = group_count;
+    ignore = tile_count;
+    ignore = k_batch;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
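The heart of the new dispatch is the mapping from a flat, persistent tile index to a (group, local tile) pair: each workgroup walks the per-group tile ranges until it finds the range containing its `tile_id`, then re-indexes itself inside that group. The standalone host-side sketch below mirrors that search so it can be checked outside HIP; the per-group tile counts stand in for `Block2ETileMapKSplit::CalculateGridSize`, and all names here are illustrative, not part of the CK API.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Illustration only: per_group_tiles[g] plays the role of the per-group grid size
// returned by the block-to-C-tile map. Returns {group_id, local_tile_id}.
std::pair<std::size_t, std::size_t>
map_tile_to_group(std::size_t tile_id, const std::vector<std::size_t>& per_group_tiles)
{
    std::size_t offset = 0;
    for(std::size_t group_id = 0; group_id < per_group_tiles.size(); ++group_id)
    {
        if(tile_id < offset + per_group_tiles[group_id])
        {
            return {group_id, tile_id - offset}; // local tile index within the group
        }
        offset += per_group_tiles[group_id];
    }
    assert(false && "tile_id out of range");
    return {per_group_tiles.size(), 0};
}

int main()
{
    // Three GEMM groups contributing 4, 2 and 6 output tiles respectively.
    std::vector<std::size_t> tiles{4, 2, 6};
    assert(map_tile_to_group(5, tiles) == std::make_pair(std::size_t{1}, std::size_t{1}));
    assert(map_tile_to_group(11, tiles) == std::make_pair(std::size_t{2}, std::size_t{5}));
}
```

In the kernel the same walk happens incrementally, and after processing one tile the workgroup advances by the launch grid size and repeats until `tile_count` is exhausted (a persistent-workgroup loop).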
@@ -186,35 +263,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                                         LoopSched,
                                                         PipelineVersion::v2>;

     using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
+    using GridwiseGemmArg = typename GridwiseGemm::Argument;
+    using KernelArguments = GroupedGemmKernelArguments;
     using Block2ETileMapKSplit =
         BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
     // Block2CTileMap configuration parameter.
     static constexpr index_t B2E_M01 = 8;

-    // using GroupedGemmBlock2ETileMap = LocalBlockToCTileMap<Block2ETileMapKSplit>;
-    using KernelArgument = typename GridwiseGemm::Argument;
-
-    struct GemmTransKernelArg
-    {
-        using B2CType = Block2ETileMapKSplit;
-
-        KernelArgument karg_;
-        Block2ETileMapKSplit block_2_ctile_map_;
-        index_t block_start_, block_end_;
-
-        GemmTransKernelArg() = default;
-        GemmTransKernelArg(KernelArgument&& karg,
-                           Block2ETileMapKSplit&& b2c_map,
-                           index_t block_start,
-                           index_t block_end)
-            : karg_{karg},
-              block_2_ctile_map_{b2c_map},
-              block_start_{block_start},
-              block_end_{block_end}
-        {
-        }
-    };
-
     static constexpr index_t DefaultKBatch = 1;

     // Argument
@@ -227,7 +282,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                  std::vector<GemmDesc>& gemm_descs)
             : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
         {
-            // TODO: use occupancy api to calculate appropriate batch size.
         }

         Argument(std::vector<const void*>& p_As,

@@ -235,9 +289,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                  std::vector<void*>& p_Es,
                  std::vector<GemmDesc>& gemm_descs,
                  index_t kbatch)
-            : K_BATCH{kbatch}
+            : K_BATCH{kbatch}, group_count_{0}, skipped_group_count_{0}, grid_size_{0}
         {
-            grid_size_   = 0;
             group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

             if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&

@@ -249,8 +302,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             gemm_kernel_args_.reserve(group_count_);

-            skipped_group_count_ = 0;
-
             for(std::size_t i = 0; i < gemm_descs.size(); ++i)
             {
                 const index_t M = gemm_descs[i].M_;
@@ -267,46 +318,29 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 const index_t stride_b = gemm_descs[i].stride_B_;
                 const index_t stride_c = gemm_descs[i].stride_C_;

-                const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
-                const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
-                const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
-                const index_t k0       = GridwiseGemm::CalculateK0(K, K_BATCH);
-
                 const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
                 auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

-                const index_t block_start = grid_size_;
-                const index_t block_end   = grid_size_ + grid_size_grp;
-
                 grid_size_ += grid_size_grp;

-                auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
-                                           type_convert<const BDataType*>(p_Bs[i]),
-                                           type_convert<EDataType*>(p_Es[i]),
-                                           M,
-                                           N,
-                                           K,
-                                           stride_a,
-                                           stride_b,
-                                           stride_c,
-                                           m_padded,
-                                           n_padded,
-                                           k_padded,
-                                           k0,
-                                           K_BATCH};
-
-                gemm_kernel_args_.emplace_back(
-                    std::move(karg), std::move(local_b2c_tile_map), block_start, block_end);
+                gemm_kernel_args_.emplace_back(type_convert<const ADataType*>(p_As[i]),
+                                               type_convert<const BDataType*>(p_Bs[i]),
+                                               type_convert<EDataType*>(p_Es[i]),
+                                               M, N, K, stride_a, stride_b, stride_c);
             }
         }

-        /**
-         * @brief Recalculate group grid size for all gemms and update B2C maps.
-         *
-         * @param[in] kbatch The new splitK parameter value.
-         */
+        ///
+        /// @brief Set new kbatch value.
+        ///
+        /// @param[in] kbatch The new splitK parameter value.
+        ///
         void UpdateKBatch(index_t kbatch)
         {
             K_BATCH = kbatch;
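For reference, the `grid_size_` accumulated in the constructor above is no longer a per-group block range but the total number of output tiles over all groups (including the split-K factor), and it later becomes the kernel's `tile_count`. A rough host-side sketch of that bookkeeping, under the assumption that the K-split tile map yields `ceil(M/MPerBlock) * ceil(N/NPerBlock) * k_batch` workgroups per group (names and constants are illustrative):

```cpp
#include <cstdint>

// Assumption for illustration: one workgroup per (MPerBlock x NPerBlock) output tile
// and per k-batch slice, mirroring Block2ETileMapKSplit::CalculateGridSize.
std::int64_t group_grid_size(std::int64_t M, std::int64_t N,
                             std::int64_t MPerBlock, std::int64_t NPerBlock, std::int64_t k_batch)
{
    const auto m_tiles = (M + MPerBlock - 1) / MPerBlock;
    const auto n_tiles = (N + NPerBlock - 1) / NPerBlock;
    return m_tiles * n_tiles * k_batch;
}

int main()
{
    std::int64_t tile_count         = 0;
    const std::int64_t problems[][2] = {{512, 768}, {129, 64}}; // per-group (M, N)
    for(const auto& p : problems)
        tile_count += group_grid_size(p[0], p[1], /*MPerBlock=*/128, /*NPerBlock=*/128, /*k_batch=*/2);
    // 4*6*2 + 2*1*2 = 52 output tiles in total.
    return tile_count == 52 ? 0 : 1;
}
```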
@@ -315,28 +349,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
             {
-                auto& karg = gemm_kernel_args_[i].karg_;
-
-                const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
-                const index_t k0       = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
+                auto& gemm_arg = gemm_kernel_args_[i];

                 const auto c_grid_desc_m_n =
-                    GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
+                    GridwiseGemm::MakeCGridDescriptor_M_N(gemm_arg.M, gemm_arg.N, gemm_arg.StrideC);

                 auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

-                const index_t block_start = grid_size_;
-                const index_t block_end   = grid_size_ + grid_size_grp;
-
                 grid_size_ += grid_size_grp;
-
-                karg.KPadded = k_padded;
-                karg.K0      = k0;
-                karg.k_batch = K_BATCH;
-                gemm_kernel_args_[i].block_2_ctile_map_ = local_b2c_tile_map;
-                gemm_kernel_args_[i].block_start_       = block_start;
-                gemm_kernel_args_[i].block_end_         = block_end;
             }
         }
@@ -344,31 +364,165 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         index_t K_BATCH;
         index_t group_count_;
         index_t skipped_group_count_;

-        std::vector<GemmTransKernelArg> gemm_kernel_args_;
+        // The overall number of output tiles to be processed.
         index_t grid_size_;
+        const void* p_dev_gemm_args_;
+        std::vector<KernelArguments> gemm_kernel_args_;
     };

     // Invoker
     struct Invoker : public BaseInvoker
     {
+        // The oversubscription factor for the number of blocks that can simultaneously reside on
+        // GPU.
+        static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
+        static constexpr int BLOCK_WAVES               = BlockSize / get_warp_size();
+        static constexpr int CU_SIMDS                  = 4;
+        // Assume we want to have at most 2 waves per SIMD
+        static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
+
+        ///
+        /// @brief      Launch Grouped Gemm kernel.
+        ///
+        /// @note       This function overload is using user provided device buffer for kernel
+        ///             arguments.
+        ///
+        /// @param[in]  arg            The structure containing kernel arguments (in host memory).
+        /// @param[in]  dev_gemm_args  The pointer to device memory with kernel arguments.
+        /// @param[in]  stream_config  The device stream configuration.
+        ///
+        /// @return     The average kernel execution time (if time measurement is enabled.)
+        ///
+        float Run(const Argument& arg,
+                  const void* dev_gemm_args,
+                  const StreamConfig& stream_config = StreamConfig{})
+        {
+            auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
+                CheckArgument(arg, stream_config);
+
+            if(dev_gemm_args == nullptr)
+            {
+                std::ostringstream err;
+                err << "The gemm arguments workspace buffer is not allocated!"
+                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                throw std::runtime_error(err.str());
+            }
+
+            if(all_have_kbatch_gt_one)
+            {
+                for(const auto& gemm_arg : arg.gemm_kernel_args_)
+                {
+                    hip_check_error(hipMemset(
+                        gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
+                }
+            }
+
+            float ave_time = 0;
+
+            if(all_have_main_k0_block_loop)
+            {
+                if(all_have_kbatch_gt_one)
+                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
+                        arg, dev_gemm_args, stream_config);
+                else
+                    ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
+                        arg, dev_gemm_args, stream_config);
+            }
+            else
+            {
+                if(all_have_kbatch_gt_one)
+                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
+                        arg, dev_gemm_args, stream_config);
+                else
+                    ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
+                        arg, dev_gemm_args, stream_config);
+            }
+
+            return ave_time;
+        }
+
+        ///
+        /// @brief      Launch Grouped Gemm kernel.
+        ///
+        /// @note       This function overload is using device workspace buffer for kernel
+        ///             arguments. The user should call @see GetWorkSpaceSize and @see
+        ///             SetWorkSpacePointer on arg parameter to properly allocate this buffer.
+        ///
+        /// @param[in]  arg            The structure containing kernel arguments (in host memory).
+        /// @param[in]  stream_config  The device stream configuration.
+        ///
+        /// @return     The average kernel execution time (if time measurement is enabled.)
+        ///
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
-            bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
+            if(arg.p_workspace_ != nullptr)
+            {
+                hip_check_error(
+                    hipMemcpyWithStream(arg.p_workspace_,
+                                        arg.gemm_kernel_args_.data(),
+                                        arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
+                                        hipMemcpyHostToDevice,
+                                        stream_config.stream_id_));
+            }
+            else
+            {
+                std::ostringstream err;
+                err << "The gemm arguments workspace buffer is not allocated!"
+                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                throw std::runtime_error(err.str());
+            }
+
+            return Run(arg, arg.p_workspace_, stream_config);
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+
+        private:
+        auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
+        {
+            index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);
+            bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
             bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

             for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
             {
-                const auto& karg = arg.gemm_kernel_args_[i].karg_;
+                const auto& gemm_arg = arg.gemm_kernel_args_[i];
                 if(stream_config.log_level_ > 0)
                 {
-                    karg.Print();
+                    gemm_arg.Print();
                 }

-                auto kbatch = karg.k_batch;
-                if(!GridwiseGemm::CheckValidity(karg))
+                // Currently all groups use same kbatch value.
+                auto kbatch = arg.K_BATCH;
+                K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[i].K, arg.K_BATCH);
+                if(!GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr, nullptr, nullptr,
+                                                                gemm_arg.M, gemm_arg.N, gemm_arg.K,
+                                                                gemm_arg.StrideA, gemm_arg.StrideB,
+                                                                gemm_arg.StrideC,
+                                                                0, // MPadded
+                                                                0, // NPadded
+                                                                0, // KPadded
+                                                                K0, kbatch}))
                 {
                     std::ostringstream err;
                     err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__

@@ -376,7 +530,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                     throw std::runtime_error(err.str());
                 }

-                K0 = karg.K0;
                 bool not_all_have_main_k0_block_loop_same =
                     all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
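The choice between `InMemoryDataOperationEnum::AtomicAdd` and `Set` in the dispatch above is what makes split-K work: with `k_batch > 1` each workgroup computes a partial product over its K slice and accumulates it into C, which is why the output buffers are zeroed with `hipMemset` first. A small self-contained sketch of that accumulation model (plain C++, no HIP, illustrative values only):

```cpp
#include <cassert>
#include <vector>

// Reference model: C = A(MxK) * B(KxN), accumulated slice by slice as in split-K.
int main()
{
    const int M = 2, N = 2, K = 8, k_batch = 4;
    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f); // C zero-initialized

    const int k_per_batch = K / k_batch; // assume K divides evenly for this illustration
    for(int kb = 0; kb < k_batch; ++kb)  // each slice plays the role of one AtomicAdd pass
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                    C[m * N + n] += A[m * K + k] * B[k * N + n];

    assert(C[0] == 16.0f); // identical to a single pass over the full K range
}
```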
@@ -394,97 +547,75 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                     std::ostringstream err;
                     err << "Not all gemms have same kbatch value (=1 or >1)! "
                         << "group [" << i << "], kbatch: " << kbatch
-                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
+                        << ", group [0], kbatch: " << arg.K_BATCH
                         << " in " << __FILE__ << ":" << __LINE__
                         << ", in function: " << __func__;
                     throw std::runtime_error(err.str());
                 }
             }
+
+            return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
+        }

-            hip_check_error(
-                hipMemcpyWithStream(arg.p_workspace_,
-                                    arg.gemm_kernel_args_.data(),
-                                    arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
-                                    hipMemcpyHostToDevice,
-                                    stream_config.stream_id_));
-
-            float ave_time = 0;
-
-            const auto Run = [&](const auto& kernel) {
-                if(all_have_kbatch_gt_one)
-                {
-                    for(const auto& trans_arg : arg.gemm_kernel_args_)
-                    {
-                        const auto& karg = trans_arg.karg_;
-                        hip_check_error(
-                            hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
-                    }
-                }
-
-                ave_time = launch_and_time_kernel(
-                    stream_config,
-                    kernel,
-                    dim3(arg.grid_size_),
-                    dim3(BlockSize),
-                    0,
-                    cast_pointer_to_constant_address_space(arg.p_workspace_),
-                    arg.gemm_kernel_args_.size());
-            };
-
-            if(all_have_main_k0_block_loop)
-            {
-                if(all_have_kbatch_gt_one)
-                {
-                    const auto kernel =
-                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
-                                                       true,
-                                                       InMemoryDataOperationEnum::AtomicAdd>;
-                    Run(kernel);
-                }
-                else
-                {
-                    const auto kernel =
-                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set>;
-                    Run(kernel);
-                }
-            }
-            else
-            {
-                if(all_have_kbatch_gt_one)
-                {
-                    const auto kernel =
-                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
-                                                       false,
-                                                       InMemoryDataOperationEnum::AtomicAdd>;
-                    Run(kernel);
-                }
-                else
-                {
-                    const auto kernel =
-                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
-                                                       false,
-                                                       InMemoryDataOperationEnum::Set>;
-                    Run(kernel);
-                }
-            }
-
-            return ave_time;
-        }
-
-        // polymorphic
-        float Run(const BaseArgument* p_arg,
-                  const StreamConfig& stream_config = StreamConfig{}) override
-        {
-            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
+        float DispatchKernel(const Argument& arg,
+                             const void* dev_gemm_args,
+                             const StreamConfig& stream_config) const
+        {
+            const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
+                                                               KernelArguments,
+                                                               ADataType,
+                                                               BDataType,
+                                                               EDataType,
+                                                               HasMainKBlockLoop,
+                                                               CGlobalMemoryDataOperation>;
+            return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
+        }
+
+        template <typename KernelFunction>
+        int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
+                                          const StreamConfig& stream_config) const
+        {
+            // Calculate max number of workgroups that can simultaneously reside on the CU.
+            int num_blocks                = 0;
+            size_t dyn_shared_mem_per_blk = 0;
+            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+                &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
+
+            int cu_count = getAvailableComputeUnitCount(stream_config);
+
+            if(stream_config.log_level_ > 0)
+            {
+                std::cout << "MaxActiveBlocksPerCU: " << num_blocks
+                          << ", available CUs count: " << cu_count << ", occup. grid size: "
+                          << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
+                                 BLOCK_SUBSCRIPTION_FACTOR
+                          << std::endl;
+            }
+
+            return cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR;
+        }
+
+        template <typename KernelFunction>
+        float LaunchKernel(const KernelFunction& kernel,
+                           const Argument& arg,
+                           const void* dev_gemm_args,
+                           const StreamConfig& stream_config) const
+        {
+            int max_occupancy_grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
+
+            // We launch the smaller of the number of workgroups actually needed for the tiles and
+            // the number of workgroups that maximizes GPU occupancy, because for some tile
+            // configurations the former is smaller than the latter. Launching too many workgroups
+            // means some of them would iterate through all gemm problem descriptors just to find
+            // out they have nothing to do, which is a waste of GPU cycles.
+            return launch_and_time_kernel(
+                stream_config,
+                kernel,
+                dim3(ck::math::min(arg.grid_size_, max_occupancy_grid_size)),
+                dim3(BlockSize),
+                0,
+                cast_pointer_to_constant_address_space(dev_gemm_args),
+                arg.grid_size_,
+                arg.K_BATCH);
         }
     };
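The launch grid is capped by an occupancy estimate instead of always launching one workgroup per tile. The host-side sketch below mirrors the shape of that calculation with plain integers; the constants follow the same assumptions as the source (four SIMDs per CU, at most two waves per SIMD, subscription factor 1), and the device numbers in `main` are purely hypothetical.

```cpp
#include <algorithm>
#include <cstdio>

int occupancy_grid_size(int max_blocks_per_cu, // e.g. reported by the occupancy API
                        int cu_count,
                        int block_waves)       // BlockSize / warp size
{
    const int cu_simds            = 4;
    const int cu_blocks           = (2 * cu_simds) / block_waves; // at most 2 waves per SIMD
    const int subscription_factor = 1;
    return cu_count * std::min(max_blocks_per_cu, cu_blocks) * subscription_factor;
}

int main()
{
    const int tile_count = 10000;
    // Hypothetical device: 104 CUs, 2 resident blocks per CU, 256-thread blocks (4 waves).
    const int grid = std::min(tile_count,
                              occupancy_grid_size(/*max_blocks_per_cu=*/2,
                                                  /*cu_count=*/104,
                                                  /*block_waves=*/4));
    std::printf("launch %d persistent workgroups for %d tiles\n", grid, tile_count); // 208
}
```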
@@ -496,11 +627,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
-        {
-            return false;
-        }
-
         if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
             arg.skipped_group_count_) != arg.group_count_)
         {

@@ -515,14 +641,28 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         bool supported = true;
         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
         {
-            const auto& a = arg.gemm_kernel_args_[i].karg_;
-            bool group_arg_valid = GridwiseGemm::CheckValidity(a);
+            const auto& gemm_arg = arg.gemm_kernel_args_[i];
+            const auto K0        = GridwiseGemm::CalculateK0(gemm_arg.K, arg.K_BATCH);
+            bool group_arg_valid =
+                GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr, nullptr, nullptr,
+                                                            gemm_arg.M, gemm_arg.N, gemm_arg.K,
+                                                            gemm_arg.StrideA, gemm_arg.StrideB,
+                                                            gemm_arg.StrideC,
+                                                            0, // MPadded
+                                                            0, // NPadded
+                                                            0, // KPadded
+                                                            K0, arg.K_BATCH});
             if(not group_arg_valid)
             {
 #if DEBUG_LOG
                 std::cout << "[" << __func__ << "] group id: " << i
                           << " has invalid GridwiseGemm settings!" << std::endl;
-                a.Print();
+                gemm_arg.Print();
 #endif // DEBUG_LOG
             }
             supported = supported && group_arg_valid;
@@ -530,7 +670,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         return supported;
     }

-    // polymorphic
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));

@@ -550,7 +689,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     static auto MakeInvoker() { return Invoker{}; }

-    // polymorphic
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(std::vector<const void*>& p_As,
                         std::vector<const void*>& p_Bs,

@@ -564,19 +702,17 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
     }

-    // polymorphic
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }

-    // polymorphic
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedGemm_XdlSplitK"
+        str << "DeviceGroupedGemm_XdlSplitKTileLoop"
             << "<"
             << std::string(ALayout::name)[0] << ","
             << std::string(BLayout::name)[0] << ","
@@ -595,6 +731,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
+           << ABlockTransferThreadClusterLengths_K0_M_K1{} << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">";
         // clang-format on

@@ -605,16 +742,24 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
         return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
-               sizeof(GemmTransKernelArg);
+               sizeof(KernelArguments);
     }

     static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }

-    // polymorphic
+    static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
+    {
+        arg.p_dev_gemm_args_ = p_dev_kernel_args;
+    }
+
     void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
     {
         return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
     }
+
+    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
+    {
+        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
+    }
 };

 } // namespace device
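With this change the caller is expected to size and fill a device-side argument buffer: `GetWorkSpaceSize` now returns `group_count * sizeof(KernelArguments)`, and the invoker copies the host-side argument vector into that buffer before launch. A hedged usage sketch follows; `DeviceOpT`/`ArgT` are placeholders for this device operation and its argument created through the usual `MakeArgumentPointer`/`MakeInvokerPointer` flow, and the exact workspace-setter signature should be checked against the CK headers.

```cpp
// Sketch only: not the library's documented example, just the intended call pattern.
#include <hip/hip_runtime.h>
#include <cstddef>

template <typename DeviceOpT, typename ArgT>
void* upload_kernel_args(DeviceOpT& device_op, ArgT* argument)
{
    // One GroupedGemmKernelArguments entry per (non-skipped) GEMM group.
    const std::size_t workspace_bytes = device_op.GetWorkSpaceSize(argument);

    void* dev_gemm_args = nullptr;
    (void)hipMalloc(&dev_gemm_args, workspace_bytes);

    // Hand the buffer to the op (SetDeviceKernelArgs is added by this commit); alternatively,
    // register it as the argument's workspace pointer so that Invoker::Run(arg, stream_config)
    // can hipMemcpyWithStream the host-side argument vector into it before launching.
    device_op.SetDeviceKernelArgs(argument, dev_gemm_args);
    return dev_gemm_args;
}
```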
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
deleted 100644 → 0  (view file @ 50540084)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <tuple>

#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in]  gemm_descs_const  The pointer to the array of GEMM descriptor structures.
/// @param[in]  tile_count        The overall number of output tiles we divided all groups into.
/// @param[in]  k_batch           The number of batches we split the K dimension into.
///
/// @tparam     GridwiseGemm      The specific GridwiseGEMM algorithm implementation.
/// @tparam     GemmDesc          The structure holding all necessary descriptors and other data
///                               needed for grouped gemm calculation and work distribution.
/// @tparam     HasMainKBlockLoop Flag indicating whether all GEMM problem configurations need to
///                               loop over tiles in the K dimension.
/// @tparam     CGlobalMemoryDataOperation  The functor used to store data in the output C matrix,
///                               for example AtomicAdd or Store.
///
template <typename GridwiseGemm,
          typename GemmDesc,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                       const index_t tile_count,
                                       const index_t k_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
    __shared__ uint8_t p_shared[shared_size];

    index_t tile_id          = get_block_1d_id();
    const index_t grid_size  = get_grid_size();
    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
        cast_pointer_to_generic_address_space(gemm_descs_const));

    static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
    static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
    static constexpr index_t B2E_M01   = 8;
    using CGridDesc_M_N        = typename GridwiseGemm::CGridDesc_M_N;
    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;

    index_t group_id = 0;
    index_t offset   = 0;
    auto M           = gemm_desc_ptr[group_id].M;
    auto N           = gemm_desc_ptr[group_id].N;
    auto StrideC     = gemm_desc_ptr[group_id].StrideC;

    auto c_grid_desc_m_n  = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
    auto b2c_tile_map     = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
    index_t grid_size_grp = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

    index_t gemm_tile_id_start = 0;
    index_t gemm_tile_id_end   = grid_size_grp;

    while(tile_id < tile_count)
    {
        // Find the corresponding GEMM group for the output tile
        while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end))
        {
            offset += grid_size_grp;
            group_id++;

            M       = gemm_desc_ptr[group_id].M;
            N       = gemm_desc_ptr[group_id].N;
            StrideC = gemm_desc_ptr[group_id].StrideC;

            c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
            b2c_tile_map    = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
            grid_size_grp   = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

            gemm_tile_id_start = offset;
            gemm_tile_id_end   = offset + grid_size_grp;
        }

        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
        const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);

        const auto K       = gemm_desc_ptr[group_id].K;
        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
        const auto MPadded = GridwiseGemm::CalculateMPadded(M);
        const auto NPadded = GridwiseGemm::CalculateNPadded(N);
        const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
        const auto K0      = GridwiseGemm::CalculateK0(K, k_batch);

        LocalBlockToCTileMap<Block2ETileMapKSplit> local_b2c{b2c_tile_map, tile_id - offset};

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
            p_a_grid, p_b_grid, p_c_grid, M, N, K, StrideA, StrideB, StrideC,
            MPadded, NPadded, KPadded, K0, k_batch, static_cast<void*>(p_shared), local_b2c);

        tile_id += grid_size;
    }
#else
    ignore = gemm_descs_const;
    ignore = tile_count;
    ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
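Inside the loop each workgroup re-indexes itself as `tile_id - offset` before handing off to the tile map, so the underlying `Block2ETileMapKSplit` only ever sees group-local tile indices. The adapter below sketches that wrapping in isolation; it is not the CK `LocalBlockToCTileMap`, and the row-major stand-in map is purely illustrative.

```cpp
#include <cassert>
#include <utility>

// Illustration only: remap a global, flat tile index into the local index space of one
// GEMM group before delegating to that group's tile map.
template <typename UnderlyingMap>
struct LocalTileMapSketch
{
    UnderlyingMap map_;
    int local_tile_id_; // e.g. tile_id - offset computed by the group-search loop

    auto CalculateBottomIndex() const { return map_.CalculateBottomIndex(local_tile_id_); }
};

struct RowMajorTileMap // stand-in for Block2ETileMapKSplit
{
    int n_tiles_per_row;
    std::pair<int, int> CalculateBottomIndex(int tile) const
    {
        return {tile / n_tiles_per_row, tile % n_tiles_per_row}; // (m-tile, n-tile)
    }
};

int main()
{
    LocalTileMapSketch<RowMajorTileMap> local{RowMajorTileMap{4}, /*local_tile_id=*/6};
    assert(local.CalculateBottomIndex() == std::make_pair(1, 2));
}
```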
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler(),
          // Current implementation does not support multiple D fusions.
          enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
                          is_same_v<DsDataType, ck::Tuple<>>,
                      bool> = false>
struct DeviceGroupedGemmXdlSplitKCShuffle
    : public DeviceGroupedGemmSplitK<ALayout,
                                     BLayout,
                                     DsLayout,
                                     ELayout,
                                     ADataType,
                                     BDataType,
                                     DsDataType,
                                     EDataType,
                                     AElementwiseOperation,
                                     BElementwiseOperation,
                                     CDEElementwiseOperation>
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static_assert(KPerBlock % AK1 == 0);
    static constexpr index_t K0PerBlock = KPerBlock / AK1;

    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        EDataType,
        ALayout,
        BLayout,
        ELayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        GemmSpec,
        NumGemmKPrefetchStage,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        AK1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferScalarPerVector_NPerBlock,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        LoopSched,
        PipelineVersion::v2>;

    using CGridDesc_M_N   = typename GridwiseGemm::CGridDesc_M_N;
    using GridwiseGemmArg = typename GridwiseGemm::Argument;
    using KernelArguments = GroupedGemmKernelArguments;
    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;

    // Block2CTileMap configuration parameter.
    static constexpr index_t B2E_M01       = 8;
    static constexpr index_t DefaultKBatch = 1;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs)
            : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
        {
        }

        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 index_t kbatch)
            : K_BATCH{kbatch}, group_count_{0}, skipped_group_count_{0}, grid_size_{0}
        {
            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
            }

            gemm_kernel_args_.reserve(group_count_);

            for(std::size_t i = 0; i < gemm_descs.size(); ++i)
            {
                const index_t M = gemm_descs[i].M_;
                const index_t N = gemm_descs[i].N_;
                const index_t K = gemm_descs[i].K_;

                if(M == 0)
                {
                    skipped_group_count_++;
                    continue;
                }

                const index_t stride_a = gemm_descs[i].stride_A_;
                const index_t stride_b = gemm_descs[i].stride_B_;
                const index_t stride_c = gemm_descs[i].stride_C_;

                const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
                auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                grid_size_ += grid_size_grp;

                gemm_kernel_args_.emplace_back(type_convert<const ADataType*>(p_As[i]),
                                               type_convert<const BDataType*>(p_Bs[i]),
                                               type_convert<EDataType*>(p_Es[i]),
                                               M, N, K, stride_a, stride_b, stride_c);
            }
        }

        /**
         * @brief Set new kbatch value.
         *
         * @param[in] kbatch The new splitK parameter value.
         */
        void UpdateKBatch(index_t kbatch)
        {
            K_BATCH    = kbatch;
            grid_size_ = 0;

            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
            {
                auto& gemm_arg = gemm_kernel_args_[i];

                const auto c_grid_desc_m_n =
                    GridwiseGemm::MakeCGridDescriptor_M_N(gemm_arg.M, gemm_arg.N, gemm_arg.StrideC);

                auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                grid_size_ += grid_size_grp;
            }
        }

        //  private:
        index_t K_BATCH;
        index_t group_count_;
        index_t skipped_group_count_;
        // The overall number of output tiles to be processed.
        index_t grid_size_;
        const void* p_dev_gemm_args_;
        std::vector<KernelArguments> gemm_kernel_args_;
    };
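`UpdateKBatch` (reached through `SetKBatchSize`) only re-accumulates `grid_size_`, because the split-K factor changes how many output tiles each group contributes. A tiny illustrative calculation of that trade-off, again assuming the tile map yields `m_tiles * n_tiles * k_batch` workgroups per group (values below are hypothetical):

```cpp
#include <cstdio>

// Illustration only: how the split-K factor trades per-workgroup K work for more output tiles.
int main()
{
    const long M = 256, N = 256, K = 4096, MPerBlock = 128, NPerBlock = 128;
    const long mn_tiles = ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);

    for(long k_batch : {1L, 2L, 4L})
    {
        const long tiles       = mn_tiles * k_batch; // what UpdateKBatch re-accumulates into grid_size_
        const long k_per_batch = K / k_batch;        // K range each workgroup reduces over
        std::printf("k_batch=%ld -> %ld tiles, K slice of %ld per workgroup\n",
                    k_batch, tiles, k_per_batch);
    }
}
```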
    // Invoker
    struct Invoker : public BaseInvoker
    {
        // The oversubscription factor for the number of blocks that can simultaneously reside on
        // GPU.
        static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
        static constexpr int BLOCK_WAVES               = BlockSize / get_warp_size();
        static constexpr int CU_SIMDS                  = 4;
        // Assume we want to have at most 2 waves per SIMD
        static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);

        ///
        /// @brief      Launch Grouped Gemm kernel.
        ///
        /// @note       This function overload is using user provided device buffer for kernel
        ///             arguments.
        ///
        /// @param[in]  arg            The structure containing kernel arguments (in host memory).
        /// @param[in]  dev_gemm_args  The pointer to device memory with kernel arguments.
        /// @param[in]  stream_config  The device stream configuration.
        ///
        /// @return     The average kernel execution time (if time measurement is enabled.)
        ///
        float Run(const Argument& arg,
                  const void* dev_gemm_args,
                  const StreamConfig& stream_config = StreamConfig{})
        {
            auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
                CheckArgument(arg, stream_config);

            if(dev_gemm_args == nullptr)
            {
                std::ostringstream err;
                err << "The gemm arguments workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            if(all_have_kbatch_gt_one)
            {
                for(const auto& gemm_arg : arg.gemm_kernel_args_)
                {
                    hip_check_error(hipMemset(
                        gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
                }
            }

            float ave_time = 0;

            if(all_have_main_k0_block_loop)
            {
                if(all_have_kbatch_gt_one)
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
                        arg, dev_gemm_args, stream_config);
                else
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
                        arg, dev_gemm_args, stream_config);
            }
            else
            {
                if(all_have_kbatch_gt_one)
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
                        arg, dev_gemm_args, stream_config);
                else
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
                        arg, dev_gemm_args, stream_config);
            }

            return ave_time;
        }

        ///
        /// @brief      Launch Grouped Gemm kernel.
        ///
        /// @note       This function overload is using device workspace buffer for kernel
        ///             arguments. The user should call @see GetWorkSpaceSize and @see
        ///             SetWorkSpacePointer on arg parameter to properly allocate this buffer.
        ///
        /// @param[in]  arg            The structure containing kernel arguments (in host memory).
        /// @param[in]  stream_config  The device stream configuration.
        ///
        /// @return     The average kernel execution time (if time measurement is enabled.)
        ///
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(arg.p_workspace_ != nullptr)
            {
                hip_check_error(
                    hipMemcpyWithStream(arg.p_workspace_,
                                        arg.gemm_kernel_args_.data(),
                                        arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
                                        hipMemcpyHostToDevice,
                                        stream_config.stream_id_));
            }
            else
            {
                std::ostringstream err;
                err << "The gemm arguments workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            return Run(arg, arg.p_workspace_, stream_config);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }

        private:
        auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
        {
            index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);

            bool all_have_kbatch_gt_one      = arg.K_BATCH > 1;
            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
            {
                const auto& gemm_arg = arg.gemm_kernel_args_[i];
                if(stream_config.log_level_ > 0)
                {
                    // gemm_arg.Print();
                }

                // Currently all groups use same kbatch value.
                auto kbatch = arg.K_BATCH;
                K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[i].K, arg.K_BATCH);

                if(!GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr, nullptr, nullptr,
                                                                gemm_arg.M, gemm_arg.N, gemm_arg.K,
                                                                gemm_arg.StrideA, gemm_arg.StrideB,
                                                                gemm_arg.StrideC,
                                                                0, // MPadded
                                                                0, // NPadded
                                                                0, // KPadded
                                                                K0, kbatch}))
                {
                    std::ostringstream err;
                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                bool not_all_have_main_k0_block_loop_same =
                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);

                if(not_all_have_main_k0_block_loop_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                if(not_all_have_kbatch_value_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same kbatch value (=1 or >1)! "
                        << "group [" << i << "], kbatch: " << kbatch
                        << ", group [0], kbatch: " << arg.K_BATCH << " in " << __FILE__ << ":"
                        << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }
            }

            return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
        }

        template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
        float DispatchKernel(const Argument& arg,
                             const void* dev_gemm_args,
                             const StreamConfig& stream_config) const
        {
            const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                               KernelArguments,
                                                               ADataType,
                                                               BDataType,
                                                               EDataType,
                                                               HasMainKBlockLoop,
                                                               CGlobalMemoryDataOperation>;
            return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
        }

        template <typename KernelFunction>
        int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
                                          const StreamConfig& stream_config) const
        {
            // Calculate max number of workgroups that can simultaneously reside on the CU.
            int num_blocks                = 0;
            size_t dyn_shared_mem_per_blk = 0;
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));

            int cu_count = getAvailableComputeUnitCount(stream_config);

            if(stream_config.log_level_ > 0)
            {
                std::cout << "MaxActiveBlocksPerCU: " << num_blocks
                          << ", available CUs count: " << cu_count << ", grid size: "
                          << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
                                 BLOCK_SUBSCRIPTION_FACTOR
                          << std::endl;
            }

            return cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR;
        }

        template <typename KernelFunction>
        float LaunchKernel(const KernelFunction& kernel,
                           const Argument& arg,
                           const void* dev_gemm_args,
                           const StreamConfig& stream_config) const
        {
            int max_occupancy_grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);

            // We launch the smaller of the number of workgroups actually needed for the tiles and
            // the number of workgroups that maximizes GPU occupancy, because for some tile
            // configurations the former is smaller than the latter. Launching too many workgroups
            // means some of them would iterate through all gemm problem descriptors just to find
            // out they have nothing to do, which is a waste of GPU cycles.
            return launch_and_time_kernel(
                stream_config,
                kernel,
                dim3(ck::math::min(arg.grid_size_, max_occupancy_grid_size)),
                dim3(BlockSize),
                0,
                cast_pointer_to_constant_address_space(dev_gemm_args),
                arg.grid_size_,
                arg.K_BATCH);
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
            arg.skipped_group_count_) != arg.group_count_)
        {
#if DEBUG_LOG
            std::cout << "The group count is not equal to sum of skipped groups "
                         "and kernel args size!"
                      << std::endl;
#endif // DEBUG_LOG
            return false;
        }

        bool supported = true;
        for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
        {
            const auto& gemm_arg = arg.gemm_kernel_args_[i];
            const auto K0        = GridwiseGemm::CalculateK0(gemm_arg.K, arg.K_BATCH);
            bool group_arg_valid =
                GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr, nullptr, nullptr,
                                                            gemm_arg.M, gemm_arg.N, gemm_arg.K,
                                                            gemm_arg.StrideA, gemm_arg.StrideB,
                                                            gemm_arg.StrideC,
                                                            0, // MPadded
                                                            0, // NPadded
                                                            0, // KPadded
                                                            K0, arg.K_BATCH});
            if(not group_arg_valid)
            {
#if DEBUG_LOG
                std::cout << "[" << __func__ << "] group id: " << i
                          << " has invalid GridwiseGemm settings!" << std::endl;
                gemm_arg.Print();
#endif // DEBUG_LOG
            }
            supported = supported && group_arg_valid;
        }
        return supported;
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*>& p_As,
                             std::vector<const void*>& p_Bs,
                             std::vector<std::array<const void*, NumDTensor>>&,
                             std::vector<void*>& p_Es,
                             std::vector<GemmDesc> gemm_descs,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CDEElementwiseOperation)
    {
        return Argument{p_As, p_Bs, p_Es, gemm_descs};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_As,
                        std::vector<const void*>& p_Bs,
                        std::vector<std::array<const void*, NumDTensor>>&,
                        std::vector<void*>& p_Es,
                        std::vector<GemmDesc>& gemm_descs,
                        AElementwiseOperation,
                        BElementwiseOperation,
                        CDEElementwiseOperation) override
    {
        return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "DeviceGroupedGemm_XdlSplitKTileLoop"
            << "<"
            << std::string(ALayout::name)[0] << ","
            << std::string(BLayout::name)[0] << ","
            << std::string(ELayout::name)[0] << ","
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << ABlockTransferThreadClusterLengths_K0_M_K1{} << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">";
        // clang-format on
        return str.str();
    }

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
               sizeof(KernelArguments);
    }

    static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }

    static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
    {
        arg.p_dev_gemm_args_ = p_dev_kernel_args;
    }

    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
    {
        return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
    }

    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
    {
        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck