gaoqiong / composable_kernel · Commits

Commit 3d345953
authored Jul 20, 2023 by Adam Osewski

Update API.

parent 0e33fbdf

Showing 3 changed files with 91 additions and 108 deletions (+91 -108)
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp                                     +0  -51
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp                              +67 -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp  +24 -57
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp

@@ -12,57 +12,6 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-/**
- * @brief Structure representing single GEMM problem arguments.
- *
- * The pointer to the vector of those structures is passed
- * to the GroupedGEMM entry point kernel.
- */
-struct GemmKernelArguments
-{
-    __host__ __device__ GemmKernelArguments(const void* p_a_grid_,
-                                            const void* p_b_grid_,
-                                            void* p_c_grid_,
-                                            index_t M_,
-                                            index_t N_,
-                                            index_t K_,
-                                            index_t StrideA_,
-                                            index_t StrideB_,
-                                            index_t StrideC_)
-        : p_a_grid{p_a_grid_},
-          p_b_grid{p_b_grid_},
-          p_c_grid{p_c_grid_},
-          M{M_},
-          N{N_},
-          K{K_},
-          StrideA{StrideA_},
-          StrideB{StrideB_},
-          StrideC{StrideC_}
-    {
-    }
-
-    const void* p_a_grid;
-    const void* p_b_grid;
-    void* p_c_grid;
-    index_t M;
-    index_t N;
-    index_t K;
-    index_t StrideA;
-    index_t StrideB;
-    index_t StrideC;
-
-    void Print() const
-    {
-        std::cout << "arg {"
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SC:" << StrideC << "}" << std::endl;
-    }
-};
-
 struct GemmDesc
 {
     ck::index_t M_, N_, K_;
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp

@@ -8,6 +8,57 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/**
+ * @brief Structure representing single GEMM problem arguments.
+ *
+ * The pointer to the vector of those structures is passed
+ * to the GroupedGEMM entry point kernel.
+ */
+struct GroupedGemmKernelArguments
+{
+    __host__ __device__ GroupedGemmKernelArguments(const void* p_a_grid_,
+                                                   const void* p_b_grid_,
+                                                   void* p_c_grid_,
+                                                   index_t M_,
+                                                   index_t N_,
+                                                   index_t K_,
+                                                   index_t StrideA_,
+                                                   index_t StrideB_,
+                                                   index_t StrideC_)
+        : p_a_grid{p_a_grid_},
+          p_b_grid{p_b_grid_},
+          p_c_grid{p_c_grid_},
+          M{M_},
+          N{N_},
+          K{K_},
+          StrideA{StrideA_},
+          StrideB{StrideB_},
+          StrideC{StrideC_}
+    {
+    }
+
+    const void* p_a_grid;
+    const void* p_b_grid;
+    void* p_c_grid;
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t StrideA;
+    index_t StrideB;
+    index_t StrideC;
+
+    void Print() const
+    {
+        std::cout << "arg {"
+                  << "M:" << M << ", "
+                  << "N:" << N << ", "
+                  << "K:" << K << ", "
+                  << "SA:" << StrideA << ", "
+                  << "SB:" << StrideB << ", "
+                  << "SC:" << StrideC << "}" << std::endl;
+    }
+};
+
 template <typename ALayout,
           typename BLayout,
           typename DsLayout,

@@ -31,7 +82,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
                                                           BElementwiseOperation,
                                                           CElementwiseOperation>
 {
+    //------------------------------------------------------------------------//
+    // @brief Sets the k batch size.
+    //
+    // @param      p_arg   Pointer to the Argument we're going to change.
+    // @param[in]  kbatch  The kbatch value.
+    //
     virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
+
+    //------------------------------------------------------------------------//
+    //
+    // @brief Sets the device kernel arguments pointer.
+    //
+    // @param      p_arg              The pointer to the Argument we're going to update.
+    // @param[in]  p_dev_kernel_args  The pointer to the device memory which contains kernel
+    //                                arguments.
+    //
+    virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const = 0;
 };
 
 } // namespace device
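For context, a minimal host-side sketch of driving the interface added above. It assumes the CK headers from this commit are on the include path, that `gemm_op` is some concrete op implementing `DeviceGroupedGemmSplitK` and `p_arg` its `BaseArgument`; the helper name `stage_kernel_args` and the kbatch value 4 are illustrative, not part of the library:

// Hypothetical host-side glue for the API added above (a sketch, not CK code).
// hip_check_error is CK's HIP status helper, used throughout this commit.
#include <hip/hip_runtime.h>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"

using ck::tensor_operation::device::GroupedGemmKernelArguments;

template <typename GroupedGemmOp>
void stage_kernel_args(GroupedGemmOp& gemm_op,
                       ck::tensor_operation::device::BaseArgument* p_arg,
                       const std::vector<GroupedGemmKernelArguments>& host_args)
{
    // Stage the per-group argument records in device memory; per the struct's
    // doc comment, the entry point kernel reads them from there.
    void* p_dev_args = nullptr;
    hip_check_error(hipMalloc(&p_dev_args,
                              host_args.size() * sizeof(GroupedGemmKernelArguments)));
    hip_check_error(hipMemcpy(p_dev_args,
                              host_args.data(),
                              host_args.size() * sizeof(GroupedGemmKernelArguments),
                              hipMemcpyHostToDevice));

    gemm_op.SetKBatchSize(p_arg, 4);                // example split-K factor
    gemm_op.SetDeviceKernelArgs(p_arg, p_dev_args); // op stores the pointer; caller keeps ownership
}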
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp

@@ -265,7 +265,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     using CGridDesc_M_N   = typename GridwiseGemm::CGridDesc_M_N;
     using GridwiseGemmArg = typename GridwiseGemm::Argument;
-    using KernelArguments = GemmKernelArguments;
+    using KernelArguments = GroupedGemmKernelArguments;
     using Block2ETileMapKSplit =
         BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
     // Block2CTileMap configuration parameter.
@@ -366,6 +366,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         index_t skipped_group_count_;
         // The overall number of output tiles to be processed.
         index_t grid_size_;
+        const void* p_dev_gemm_args_;
         std::vector<KernelArguments> gemm_kernel_args_;
     };
@@ -384,8 +385,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     //
     // @brief Launch Grouped Gemm kernel.
     //
-    // @note This function overload is using user provided device workspace buffer for
-    //       kernel arguments.
+    // @note This function overload is using user provided device buffer for kernel
+    //       arguments.
     //
     // @param[in]  arg            The structure containing kernel arguments (in host memory).
     // @param[in]  dev_gemm_args  The point to device memory with kernel arguments.
@@ -400,11 +401,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
             CheckArgument(arg, stream_config);
 
-        if(dev_gemm_args != nullptr)
-        {
-            arg.p_workspace_ = dev_gemm_args;
-        }
-        else
+        if(dev_gemm_args == nullptr)
         {
             std::ostringstream err;
             err << "The gemm arguments workspace buffer is not allocated!"
@@ -428,12 +425,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             if(all_have_kbatch_gt_one)
             {
                 ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
-                    arg, stream_config);
+                    arg, dev_gemm_args, stream_config);
             }
             else
             {
-                ave_time =
-                    DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
+                ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
+                    arg, dev_gemm_args, stream_config);
             }
         }
         else
@@ -441,12 +438,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             if(all_have_kbatch_gt_one)
             {
                 ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
-                    arg, stream_config);
+                    arg, dev_gemm_args, stream_config);
             }
             else
             {
-                ave_time =
-                    DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
+                ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
+                    arg, dev_gemm_args, stream_config);
             }
         }
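The four call sites above differ only in two compile-time parameters: the output memory op (AtomicAdd when every group has kbatch greater than one, Set otherwise) and the main-K-loop flag. A self-contained sketch of that dispatch shape, with `MemOp`, `dispatch_kernel`, and `run` as illustrative stand-ins for CK's `InMemoryDataOperationEnum`, `DispatchKernel`, and `Run`:

// Two runtime booleans select among four template instantiations that all
// exist at compile time; this mirrors the branch structure above.
#include <cstdio>

enum class MemOp { Set, AtomicAdd };

template <MemOp Op, bool HasMainKBlockLoop>
float dispatch_kernel() // stand-in for DispatchKernel(arg, dev_gemm_args, stream_config)
{
    std::printf("Op=%d main_loop=%d\n", static_cast<int>(Op), HasMainKBlockLoop);
    return 0.f; // would return the measured kernel time
}

float run(bool all_have_kbatch_gt_one, bool all_have_main_k0_block_loop)
{
    if(all_have_main_k0_block_loop)
        return all_have_kbatch_gt_one ? dispatch_kernel<MemOp::AtomicAdd, true>()
                                      : dispatch_kernel<MemOp::Set, true>();
    return all_have_kbatch_gt_one ? dispatch_kernel<MemOp::AtomicAdd, false>()
                                  : dispatch_kernel<MemOp::Set, false>();
}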
@@ -467,9 +464,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     //
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-        auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
-            CheckArgument(arg, stream_config);
-
         if(arg.p_workspace_ != nullptr)
         {
             hip_check_error(
@@ -487,45 +481,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             throw std::runtime_error(err.str());
         }
 
-        if(all_have_kbatch_gt_one)
-        {
-            for(const auto& gemm_arg : arg.gemm_kernel_args_)
-            {
-                hip_check_error(hipMemset(
-                    gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
-            }
-        }
-
-        float ave_time = 0;
-
-        if(all_have_main_k0_block_loop)
-        {
-            if(all_have_kbatch_gt_one)
-            {
-                ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
-                    arg, stream_config);
-            }
-            else
-            {
-                ave_time =
-                    DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
-            }
-        }
-        else
-        {
-            if(all_have_kbatch_gt_one)
-            {
-                ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
-                    arg, stream_config);
-            }
-            else
-            {
-                ave_time =
-                    DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
-            }
-        }
-
-        return ave_time;
+        return Run(arg, arg.p_workspace_, stream_config);
     }
 
     float Run(const BaseArgument* p_arg,
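The removed block zeroed each group's output with `hipMemset` before the `AtomicAdd` path ran. A CPU toy (illustrative numbers, not CK code) of why split-K accumulation needs a zeroed C:

// Split-K divides the K dimension into kbatch chunks; each chunk accumulates
// its partial product into C (the atomicAdd analogue), so C must start at
// zero for the running sum to equal the full-K result.
#include <cassert>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 4, kbatch = 2;
    std::vector<float> A(M * K, 1.f), B(K * N, 1.f);
    std::vector<float> C(M * N, 0.f); // the hipMemset analogue

    const int k_chunk = K / kbatch;
    for(int kb = 0; kb < kbatch; ++kb) // one pass per k-batch
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = kb * k_chunk; k < (kb + 1) * k_chunk; ++k)
                    C[m * N + n] += A[m * K + k] * B[k * N + n];

    assert(C[0] == static_cast<float>(K)); // full-K result recovered
    return 0;
}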
@@ -600,7 +556,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     }
 
     template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
-    float DispatchKernel(const Argument& arg, const StreamConfig& stream_config) const
+    float DispatchKernel(const Argument& arg,
+                         const void* dev_gemm_args,
+                         const StreamConfig& stream_config) const
     {
         const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                            KernelArguments,
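Per the doc comment on `GroupedGemmKernelArguments`, the entry point kernel receives a device pointer to an array of those records. A schematic HIP sketch of that consumption pattern; `KernelArgs` and `grouped_gemm_entry` are illustrative stand-ins, not the actual `kernel_grouped_gemm_xdl_splitk`:

// Each workgroup locates its group's problem description in the
// device-resident argument array and works on that problem's tiles.
#include <hip/hip_runtime.h>

struct KernelArgs // mirrors the POD layout of GroupedGemmKernelArguments
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    int M, N, K, StrideA, StrideB, StrideC;
};

__global__ void grouped_gemm_entry(const KernelArgs* gemm_args, int group_count)
{
    // A real tile-loop kernel maps blockIdx.x onto (group, output tile); this
    // sketch just shows one block per group reading its own argument record.
    if(blockIdx.x >= static_cast<unsigned>(group_count))
        return;
    const KernelArgs args = gemm_args[blockIdx.x];
    (void)args; // ... compute this group's GEMM tiles from args ...
}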
@@ -772,11 +730,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     }
 
     static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
 
+    static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
+    {
+        arg.p_dev_gemm_args_ = p_dev_kernel_args;
+    }
+
     void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
     {
         return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
     }
 
+    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
+    {
+        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
+    }
 };
 
 } // namespace device
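The overrides above route the type-erased `BaseArgument*` to the concrete `Argument` via `dynamic_cast`. A standalone model of the same pattern with hypothetical names (`MyDeviceOp`); note that, as in the CK code, a mismatched argument type would make the cast return `nullptr` and the dereference undefined:

// The public interface traffics in BaseArgument*; each concrete device op
// downcasts to its own Argument type before mutating it.
#include <cassert>

struct BaseArgument { virtual ~BaseArgument() = default; };

struct MyDeviceOp
{
    struct Argument : BaseArgument { const void* p_dev_gemm_args_ = nullptr; };

    static void SetDeviceKernelArgs(Argument& arg, const void* p) { arg.p_dev_gemm_args_ = p; }

    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p) const
    {
        SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p);
    }
};

int main()
{
    MyDeviceOp op;
    MyDeviceOp::Argument arg;
    int dummy = 0;
    op.SetDeviceKernelArgs(&arg, &dummy); // dispatch through the base interface
    assert(arg.p_dev_gemm_args_ == &dummy);
    return 0;
}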