Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b3512749
Commit
b3512749
authored
Jul 18, 2023
by
Adam Osewski
Browse files
Move Gemm KernelArguments to device op interface.
parent
61862fb4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
53 deletions
+53
-53
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+51
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
...mpl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
+2
-53
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
b3512749
...
...
@@ -12,6 +12,57 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
/**
 * @brief Arguments describing a single GEMM problem.
 *
 * A pointer to a vector of these structures is handed to the
 * GroupedGEMM entry point kernel.
 */
struct GemmKernelArguments
{
    // Members are initialized in this declaration order.
    // presumably the A/B input buffers and the C output buffer — confirm against the kernel.
    const void* p_a_grid; // read-only pointer
    const void* p_b_grid; // read-only pointer
    void* p_c_grid;       // writable pointer

    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    /**
     * @brief Stores each argument in the member of the same name
     *        (the parameter name without its trailing underscore).
     */
    __host__ __device__ GemmKernelArguments(const void* p_a_grid_,
                                            const void* p_b_grid_,
                                            void* p_c_grid_,
                                            index_t M_,
                                            index_t N_,
                                            index_t K_,
                                            index_t StrideA_,
                                            index_t StrideB_,
                                            index_t StrideC_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_c_grid{p_c_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideC{StrideC_}
    {
    }

    /**
     * @brief Writes M, N, K and the three strides to std::cout on one line.
     *
     * The pointer members are not printed. The line is terminated with
     * std::endl, so the stream is flushed once at the end.
     */
    void Print() const
    {
        std::cout << "arg {";
        std::cout << "M:" << M << ", ";
        std::cout << "N:" << N << ", ";
        std::cout << "K:" << K << ", ";
        std::cout << "SA:" << StrideA << ", ";
        std::cout << "SB:" << StrideB << ", ";
        std::cout << "SC:" << StrideC << "}" << std::endl;
    }
};
struct
GemmDesc
{
ck
::
index_t
M_
,
N_
,
K_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
View file @
b3512749
...
...
@@ -265,62 +265,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
// Alias for the C grid descriptor type exposed by GridwiseGemm.
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;

// Alias for the argument type exposed by GridwiseGemm.
using GridwiseGemmArg = typename GridwiseGemm::Argument;

// Per-problem kernel arguments; this is an alias of GemmKernelArguments.
using KernelArguments = GemmKernelArguments;

// Block-to-tile map type, instantiated with MPerBlock, NPerBlock and CGridDesc_M_N.
using Block2ETileMapKSplit =
    BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;

// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
/**
 * @brief Arguments describing a single GEMM problem.
 *
 * A pointer to a vector of these structures is handed to the
 * GroupedGEMM entry point kernel.
 *
 * NOTE(review): this view also contains `using KernelArguments = GemmKernelArguments;`
 * in the same scope. Both cannot coexist; confirm which definition is meant to stay.
 */
struct KernelArguments
{
    /**
     * @brief Copies every argument into the member that shares its name
     *        (minus the trailing underscore).
     */
    __host__ __device__ KernelArguments(const void* p_a_grid_,
                                        const void* p_b_grid_,
                                        void* p_c_grid_,
                                        index_t M_,
                                        index_t N_,
                                        index_t K_,
                                        index_t StrideA_,
                                        index_t StrideB_,
                                        index_t StrideC_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_c_grid{p_c_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideC{StrideC_}
    {
    }

    // Two read-only pointers followed by one writable pointer.
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;

    // Problem sizes.
    index_t M;
    index_t N;
    index_t K;

    // Strides, printed by Print() as SA, SB and SC.
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    /**
     * @brief Prints the sizes and strides (not the pointers) to std::cout
     *        as a single line ending in std::endl.
     */
    void Print() const
    {
        auto& out = std::cout;

        out << "arg {"
            << "M:" << M << ", "
            << "N:" << N << ", "
            << "K:" << K << ", ";
        out << "SA:" << StrideA << ", "
            << "SB:" << StrideB << ", "
            << "SC:" << StrideC << "}" << std::endl;
    }
};
// Compile-time default of 1 for KBatch.
// NOTE(review): presumably the split-K batch count used when the caller does not
// provide one — confirm against the Argument code that follows.
static constexpr index_t DefaultKBatch = 1;
// Argument
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment