gaoqiong / composable_kernel

Commit 52c79ace, authored Jul 18, 2023 by Adam Osewski

Change Run API to accept user provided workspace buffer.

Parent: 21fbf2ce
Showing 1 changed file with 195 additions and 118 deletions.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
...
@@ -5,6 +5,7 @@
 #include <iostream>
 #include <sstream>
 #include <tuple>
 #include "ck/ck.hpp"
 #include "ck/host_utility/device_prop.hpp"
...
@@ -431,7 +432,161 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
    // Assume we want to have at most 2 waves per SIMD
    static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);

    //
    // @brief      Launch Grouped Gemm kernel.
    //
    // @note       This function overload is using user provided device workspace buffer for
    //             kernel arguments.
    //
    // @param[in]  arg            The structure containing kernel arguments (in host memory).
    // @param[in]  dev_gemm_args  The pointer to device memory with kernel arguments.
    // @param[in]  stream_config  The device stream configuration.
    //
    // @return     The average kernel execution time (if time measurement is enabled).
    //
    float Run(const Argument& arg,
              const void* dev_gemm_args,
              const StreamConfig& stream_config = StreamConfig{})
    {
        auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
            CheckArgument(arg, stream_config);

        if(dev_gemm_args != nullptr)
        {
            arg.p_workspace_ = dev_gemm_args;
        }
        else
        {
            std::ostringstream err;
            err << "The gemm arguments workspace buffer is not allocated!"
                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        if(all_have_kbatch_gt_one)
        {
            for(const auto& gemm_arg : arg.gemm_kernel_args_)
            {
                hip_check_error(
                    hipMemset(gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
            }
        }

        float ave_time = 0;

        if(all_have_main_k0_block_loop)
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
            }
        }
        else
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
            }
        }

        return ave_time;
    }
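For illustration, here is a sketch of how a caller might drive this new overload, assuming the caller owns the device buffer that holds the per-group kernel arguments. The helper name RunWithUserBuffer, the DeviceOp/Argument template parameters, and the StreamConfig{stream, true} initialization are illustrative assumptions, not part of this commit:

    // Hedged usage sketch: the caller allocates and fills the device-side kernel-argument
    // buffer, then hands it to Run(); the overload only attaches it via p_workspace_ and
    // launches, without copying the arguments itself.
    template <typename DeviceOp, typename Argument>
    float RunWithUserBuffer(DeviceOp& gemm, const Argument& arg, hipStream_t stream)
    {
        void* dev_gemm_args = nullptr;
        const std::size_t args_bytes =
            arg.gemm_kernel_args_.size() * sizeof(arg.gemm_kernel_args_[0]);

        // Caller-owned allocation and upload of the per-group kernel arguments.
        hip_check_error(hipMalloc(&dev_gemm_args, args_bytes));
        hip_check_error(hipMemcpyAsync(dev_gemm_args,
                                       arg.gemm_kernel_args_.data(),
                                       args_bytes,
                                       hipMemcpyHostToDevice,
                                       stream));

        const float ave_time =
            gemm.Run(arg, dev_gemm_args, StreamConfig{stream, /*time_kernel=*/true});

        hip_check_error(hipStreamSynchronize(stream));
        hip_check_error(hipFree(dev_gemm_args));
        return ave_time;
    }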
    //
    // @brief      Launch Grouped Gemm kernel.
    //
    // @note       This function overload is using device workspace buffer for kernel arguments.
    //             The user should call @see GetWorkSpaceSize and @see SetWorkSpacePointer on
    //             arg parameter to properly allocate this buffer.
    //
    // @param[in]  arg            The structure containing kernel arguments (in host memory).
    // @param[in]  stream_config  The device stream configuration.
    //
    // @return     The average kernel execution time (if time measurement is enabled).
    //
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
            CheckArgument(arg, stream_config);

        if(arg.p_workspace_ != nullptr)
        {
            hip_check_error(
                hipMemcpyWithStream(arg.p_workspace_,
                                    arg.gemm_kernel_args_.data(),
                                    arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));
        }
        else
        {
            std::ostringstream err;
            err << "The gemm arguments workspace buffer is not allocated!"
                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        if(all_have_kbatch_gt_one)
        {
            for(const auto& gemm_arg : arg.gemm_kernel_args_)
            {
                hip_check_error(
                    hipMemset(gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
            }
        }

        float ave_time = 0;

        if(all_have_main_k0_block_loop)
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
            }
        }
        else
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
            }
        }

        return ave_time;
    }
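Similarly, a sketch of the workspace-based overload following the @note above: GetWorkSpaceSize and SetWorkSpacePointer size and attach the kernel-argument buffer before Run copies gemm_kernel_args_ into it with hipMemcpyWithStream. The helper name, template parameters, and StreamConfig initialization are again illustrative assumptions:

    // Hedged usage sketch: let the device op size the buffer, attach it, and let Run() copy.
    template <typename DeviceOp, typename Argument>
    float RunWithOwnedWorkspace(DeviceOp& gemm, Argument& arg, hipStream_t stream)
    {
        void* p_workspace         = nullptr;
        const std::size_t ws_size = gemm.GetWorkSpaceSize(&arg); // bytes for the argument buffer

        hip_check_error(hipMalloc(&p_workspace, ws_size));
        gemm.SetWorkSpacePointer(&arg, p_workspace); // Run() copies gemm_kernel_args_ into it

        const float ave_time = gemm.Run(arg, StreamConfig{stream, /*time_kernel=*/true});

        hip_check_error(hipStreamSynchronize(stream));
        hip_check_error(hipFree(p_workspace));
        return ave_time;
    }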
    float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
    }
    private:
    auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
    {
        index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);
        bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
...
...
@@ -492,131 +647,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                throw std::runtime_error(err.str());
            }
        }

        return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
    }

        if(arg.p_workspace_ != nullptr)
        {
            hip_check_error(
                hipMemcpyWithStream(arg.p_workspace_,
                                    arg.gemm_kernel_args_.data(),
                                    arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));
        }
        else
        {
            std::ostringstream err;
            err << "The argument workspace buffer is not allocated!"
                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        float ave_time = 0;

        const auto Run = [&](const auto& kernel) {
            if(all_have_kbatch_gt_one)
            {
                for(const auto& gemm_arg : arg.gemm_kernel_args_)
                {
                    hip_check_error(hipMemset(
                        gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
                }
            }

            // Calculate max number of workgroups that can simultaneously reside on the CU.
            int num_blocks                = 0;
            size_t dyn_shared_mem_per_blk = 0;
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));

            int cu_count = getAvailableComputeUnitCount(stream_config);

    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
    float DispatchKernel(const Argument& arg, const StreamConfig& stream_config) const
    {
        const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                           KernelArguments,
                                                           ADataType,
                                                           BDataType,
                                                           EDataType,
                                                           HasMainKBlockLoop,
                                                           CGlobalMemoryDataOperation>;
        return LaunchKernel(kernel, arg, stream_config);
    }
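In this inline diff view, lines of the previous Run body (the const auto Run = [&](const auto& kernel) lambda) are interleaved with the newly added DispatchKernel helper. DispatchKernel maps the two runtime flags (k-batch > 1, main K0 block loop) onto the compile-time template parameters CGlobalMemoryDataOperation and HasMainKBlockLoop. Below is a standalone sketch of that runtime-to-compile-time dispatch pattern, using illustrative stand-ins (MemOp, the printf body) rather than CK's real types:

    #include <cstdio>

    // Illustrative stand-in for InMemoryDataOperationEnum; only the dispatch shape matters here.
    enum class MemOp { Set, AtomicAdd };

    template <MemOp Op, bool HasMainLoop>
    float DispatchKernel()
    {
        // The real helper instantiates kernel_grouped_gemm_xdl_splitk<..., HasMainLoop, Op>
        // and forwards to LaunchKernel(); here we just report which instantiation was chosen.
        std::printf("Op=%s, HasMainLoop=%d\n",
                    Op == MemOp::AtomicAdd ? "AtomicAdd" : "Set",
                    HasMainLoop);
        return 0.f;
    }

    float Run(bool kbatch_gt_one, bool has_main_loop)
    {
        // Two runtime booleans select one of four compile-time instantiations.
        if(has_main_loop)
            return kbatch_gt_one ? DispatchKernel<MemOp::AtomicAdd, true>()
                                 : DispatchKernel<MemOp::Set, true>();
        return kbatch_gt_one ? DispatchKernel<MemOp::AtomicAdd, false>()
                             : DispatchKernel<MemOp::Set, false>();
    }

    int main() { Run(true, false); }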
            if(stream_config.log_level_ > 0)
            {
                std::cout << "MaxActiveBlocksPerCU: " << num_blocks
                          << ", available CUs count: " << cu_count << ", grid size: "
                          << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
                                 BLOCK_SUBSCRIPTION_FACTOR
                          << std::endl;
            }

    template <typename KernelFunction>
    float LaunchKernel(const KernelFunction& kernel,
                       const Argument& arg,
                       const StreamConfig& stream_config) const
    {
        // Calculate max number of workgroups that can simultaneously reside on the CU.
        int num_blocks                = 0;
        size_t dyn_shared_mem_per_blk = 0;
        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));

            ave_time = launch_and_time_kernel(
                stream_config,
                kernel,
                dim3(cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR),
                dim3(BlockSize),
                0,
                arg.p_workspace_,
                arg.grid_size_,
                arg.K_BATCH);
        };

        int cu_count = getAvailableComputeUnitCount(stream_config);

        if(all_have_main_k0_block_loop)
        {
            if(all_have_kbatch_gt_one)
            {
                const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                                   KernelArguments,
                                                                   ADataType,
                                                                   BDataType,
                                                                   EDataType,
                                                                   true,
                                                                   InMemoryDataOperationEnum::AtomicAdd>;
                Run(kernel);
            }
            else
            {
                const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                                   KernelArguments,
                                                                   ADataType,
                                                                   BDataType,
                                                                   EDataType,
                                                                   true,
                                                                   InMemoryDataOperationEnum::Set>;
                Run(kernel);
            }
        }
        else

        if(stream_config.log_level_ > 0)
        {
            if(all_have_kbatch_gt_one)
            {
                const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                                   KernelArguments,
                                                                   ADataType,
                                                                   BDataType,
                                                                   EDataType,
                                                                   false,
                                                                   InMemoryDataOperationEnum::AtomicAdd>;
                Run(kernel);
            }
            else
            {
                const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                                   KernelArguments,
                                                                   ADataType,
                                                                   BDataType,
                                                                   EDataType,
                                                                   false,
                                                                   InMemoryDataOperationEnum::Set>;
                Run(kernel);
            }

            std::cout << "MaxActiveBlocksPerCU: " << num_blocks
                      << ", available CUs count: " << cu_count << ", grid size: "
                      << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
                             BLOCK_SUBSCRIPTION_FACTOR
                      << std::endl;
        }

        return ave_time;
    }

    float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);

        return launch_and_time_kernel(
            stream_config,
            kernel,
            dim3(cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR),
            dim3(BlockSize),
            0,
            arg.p_workspace_,
            arg.grid_size_,
            arg.K_BATCH);
    }
};
...
...
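Note how LaunchKernel sizes the grid from occupancy rather than from the problem: the launch uses cu_count * min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR workgroups, so the tile-loop kernel covers arg.grid_size_ tiles with a device-sized grid. Below is a standalone sketch of that occupancy-based sizing using the public HIP occupancy API; the dummy kernel and the constant values are placeholders, not CK's:

    #include <hip/hip_runtime.h>
    #include <algorithm>
    #include <cstdio>

    // Placeholder kernel; stands in for kernel_grouped_gemm_xdl_splitk in this sketch.
    __global__ void dummy_kernel() {}

    int main()
    {
        constexpr int BlockSize                 = 256; // threads per workgroup (placeholder)
        constexpr int CU_BLOCKS                 = 2;   // cap on resident blocks per CU (placeholder)
        constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;   // oversubscription factor (placeholder)

        // Max workgroups of this kernel that can be resident on one CU at BlockSize threads.
        int num_blocks = 0;
        (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, dummy_kernel, BlockSize, 0);

        // Number of compute units on the current device.
        hipDeviceProp_t props{};
        (void)hipGetDeviceProperties(&props, 0);
        const int cu_count = props.multiProcessorCount;

        // Same shape as the launch in LaunchKernel(): an occupancy-derived, device-sized grid.
        const int grid_size = cu_count * std::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR;

        std::printf("blocks/CU=%d, CUs=%d, grid=%d\n", num_blocks, cu_count, grid_size);
        hipLaunchKernelGGL(dummy_kernel, dim3(grid_size), dim3(BlockSize), 0, nullptr);
        (void)hipDeviceSynchronize();
        return 0;
    }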