gaoqiong / composable_kernel_ROCM

Commit 7e71ea99, authored Jan 23, 2024 by Adam Osewski

    Commit debug WIP for sharing.

parent 734df790

Showing 13 changed files with 559 additions and 384 deletions (+559, -384)
include/ck/ck.hpp  (+1, -1)
include/ck/host_utility/kernel_launch.hpp  (+8, -5)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp  (+145, -49)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp  (+55, -0)
include/ck/utility/work_scheduling.hpp  (+12, -5)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp  (+149, -149)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp  (+14, -14)
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt  (+9, -9)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/CMakeLists.txt  (+1, -1)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp  (+60, -60)
profiler/include/profiler/profile_grouped_gemm_multiple_d_splitk_impl.hpp  (+15, -4)
profiler/src/CMakeLists.txt  (+84, -82)
test/work_scheduling/test_strided_reduction_tile_loop.cpp  (+6, -5)
include/ck/ck.hpp

@@ -213,7 +213,7 @@
 #define CK_WORKAROUND_SWDEV_388832 1
 // flag to enable (1) or disable (0) the debugging output in some kernels
-#define DEBUG_LOG 0
+#define DEBUG_LOG 1
 // denorm test fix, required to work around dissue
 #ifndef CK_WORKAROUND_DENORM_FIX
include/ck/host_utility/kernel_launch.hpp

@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                block_dim.y,
                block_dim.z);

-        printf("Warm up 1 time\n");
+        printf("Warm up %d times\n", stream_config.cold_niters_);
 #endif
         // warm up
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
+        for(int i = 0; i < stream_config.cold_niters_; ++i)
+        {
+            preprocess();
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+            hip_check_error(hipGetLastError());
+        }

-        const int nrepeat = 10;
+        const int nrepeat = stream_config.nrepeat_;
 #if DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
 #endif
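The net effect of this change is that the launcher now performs `stream_config.cold_niters_` untimed warm-up launches instead of exactly one, and takes the timed repeat count from `stream_config.nrepeat_` instead of a hard-coded 10. Below is a minimal host-side sketch of the same warm-up/measure pattern, using plain `std::chrono` and a generic callable in place of a HIP kernel launch; only the field names `cold_niters_` and `nrepeat_` come from the diff, everything else is illustrative.

    #include <chrono>
    #include <cstdio>
    #include <functional>

    struct TimingConfig
    {
        int cold_niters_ = 5;  // untimed warm-up runs
        int nrepeat_     = 50; // timed runs to average over
    };

    // Times an arbitrary callable the way the launcher above times a kernel:
    // run it cold_niters_ times untimed, then average nrepeat_ timed runs.
    inline float time_callable(const TimingConfig& cfg, const std::function<void()>& launch)
    {
        for(int i = 0; i < cfg.cold_niters_; ++i)
            launch(); // warm up: populate caches, trigger lazy initialization, etc.

        const auto start = std::chrono::steady_clock::now();
        for(int i = 0; i < cfg.nrepeat_; ++i)
            launch();
        const auto stop = std::chrono::steady_clock::now();

        const float total_ms = std::chrono::duration<float, std::milli>(stop - start).count();
        return total_ms / static_cast<float>(cfg.nrepeat_);
    }

Calling `time_callable(TimingConfig{0, 1}, ...)` mimics the no-warm-up, single-repetition debug configuration used later in this commit.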
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -68,15 +68,15 @@ __global__ void
                 void* const __restrict__ p_workspace,
                 const index_t tile_count,
                 const index_t k_batch,
-                const AElementwiseOperation a_element_op,
-                const BElementwiseOperation b_element_op,
+                [[maybe_unused]] const AElementwiseOperation a_element_op,
+                [[maybe_unused]] const BElementwiseOperation b_element_op,
                 [[maybe_unused]] const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
-    __shared__ uint8_t p_shared[shared_size];
+    [[maybe_unused]] __shared__ uint8_t p_shared[shared_size];

     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

@@ -105,6 +105,12 @@ __global__ void
     index_t gemm_tile_id_end = grid_size_grp;

     auto gridwise_gemm = GridwiseGemm();

+    [[maybe_unused]] auto is_thread_local_1d_id_idx = [](auto... Ids) -> bool {
+        const auto tid = get_thread_local_1d_id();
+        return ((tid == Ids) || ...);
+    };
+
     do
     {
         // Find corresponding GEMM group for our tile

@@ -123,12 +129,12 @@ __global__ void
             gemm_tile_id_end = offset + grid_size_grp;
         }

-        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
-        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
+        [[maybe_unused]] const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
+        [[maybe_unused]] const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
-        const auto K       = gemm_desc_ptr[group_id].K;
-        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
-        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
+        [[maybe_unused]] const auto K       = gemm_desc_ptr[group_id].K;
+        [[maybe_unused]] const auto StrideA = gemm_desc_ptr[group_id].StrideA;
+        [[maybe_unused]] const auto StrideB = gemm_desc_ptr[group_id].StrideB;

         auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
         b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);

@@ -137,21 +143,32 @@ __global__ void
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         // TODO: change desc so that few K-tiles will be done in single GEMM.
+        // {
+        //     if (is_thread_local_1d_id_idx(0))
+        //     {
+        //         printf("bid: %d, group: %d, accumulate tile id (M,N,K): [%d, %d, %d] \n",
+        //                static_cast<index_t>(blockIdx.x),
+        //                group_id,
+        //                b2c_tile_map.GetTileMIdx(),
+        //                b2c_tile_map.GetTileNIdx(),
+        //                b2c_tile_map.GetTileKIdx());
+        //     }
+        // }
         do
         {
             // just accumulate results in registers!
-            gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
-                                                              p_b_grid,
-                                                              static_cast<void*>(p_shared),
-                                                              a_element_op,
-                                                              b_element_op,
-                                                              M,
-                                                              N,
-                                                              K,
-                                                              StrideA,
-                                                              StrideB,
-                                                              k_batch,
-                                                              b2c_tile_map);
+            // gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+            //                                                   p_b_grid,
+            //                                                   static_cast<void*>(p_shared),
+            //                                                   a_element_op,
+            //                                                   b_element_op,
+            //                                                   M,
+            //                                                   N,
+            //                                                   K,
+            //                                                   StrideA,
+            //                                                   StrideB,
+            //                                                   k_batch,
+            //                                                   b2c_tile_map);
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

@@ -167,51 +184,122 @@ __global__ void
         work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);

+        // {
+        //     // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
+        //     //     work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
+        if(is_thread_local_1d_id_idx(0))
+        {
+            printf("bid: %d, group: %d, FlagFInished \n",
+                   static_cast<index_t>(blockIdx.x),
+                   group_id);
+            // printf("bid: %d, group: %d, FlagFInished flag_v[%u]: %u\n",
+            //        static_cast<index_t>(blockIdx.x),
+            //        group_id)
+            //        work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
+            //        flag_v2);
+        }
+        // }

         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
+            if(is_thread_local_1d_id_idx(0))
+            {
+                printf("bid: %d, group: %d, Will wait for neighbours... \n",
+                       static_cast<index_t>(blockIdx.x),
+                       group_id);
+            }
             // Wait untill all other blocks for this [M,N] tile store their results.
             work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);

             // Accumulate partial results. We can have different # of workgroups to reduce, thus we
             // read actual flag value.
-            const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
+            [[maybe_unused]] const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
                 work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));

+            // {
+            //     if (is_thread_local_1d_id_idx(0))
+            //     {
+            //         printf("bid: %d, group: %d, WaitForNeighbours flag_v[%u]: %u\n",
+            //                static_cast<index_t>(blockIdx.x),
+            //                group_id,
+            //                work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
+            //                static_cast<index_t>(blockIdx.x));
+            //         // flag_v);
+            //     }
+            // }
+            // using CThreadBuffer = remove_cvref_t<decltype(results_buffer)>;
+            // constexpr index_t n_v     = CThreadBuffer::num_of_v_.value;
+            // constexpr index_t s_per_v = CThreadBuffer::s_per_v.value;
+            // static_for<0, n_v, 1>{}([&](auto v) {
+            //     static_for<0, s_per_v, 1>{}([&](auto s) {
+            //         // printf("bid: %d; tid: %d; [Partial results] c_thread_buff[%d, %d]: %f\n",
+            //         //        static_cast<index_t>(blockIdx.x),
+            //         //        static_cast<index_t>(threadIdx.x),
+            //         //        v.value,
+            //         //        s.value,
+            //         //        static_cast<float>(results_buffer[v * Number<s_per_v>{} + s]));
+            //         results_buffer(v * Number<s_per_v>{} + s) = threadIdx.x * v + s;
+            //     });
+            // });

             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
-            if(flag_v > 1)
-                gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
+            // if(flag_v > 1)
+            //     gridwise_gemm.AccumulatePartials(p_workspace, flag_v);

+            if(is_thread_local_1d_id_idx(0))
+            {
+                printf("bid: %d, group: %d, Reset flag \n",
+                       static_cast<index_t>(blockIdx.x),
+                       group_id);
+            }
             // Signal waiting blocks that they can start use their workspace.
             work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);

-            const auto p_e_grid  = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
-            const auto stride_e  = gemm_desc_ptr[group_id].StrideE;
-            const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
-            constexpr auto NumDTensor = DsDataType::Size();
-            using DsGridPointer       = decltype(GridwiseGemm::MakeDsGridPointer());
-            DsGridPointer p_ds_grid;
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-                p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
-            });
-            gridwise_gemm.template RunWrite(p_ds_grid,
-                                            p_e_grid,
-                                            static_cast<void*>(p_shared),
-                                            M,
-                                            N,
-                                            stride_ds,
-                                            stride_e,
-                                            cde_element_op,
-                                            b2c_tile_map);
+            // const auto p_e_grid  = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
+            // const auto stride_e  = gemm_desc_ptr[group_id].StrideE;
+            // const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
+            // constexpr auto NumDTensor = DsDataType::Size();
+            // using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
+            // DsGridPointer p_ds_grid;
+            // static_for<0, NumDTensor, 1>{}([&](auto i) {
+            //     using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+            //     p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
+            // });
+            // gridwise_gemm.template RunWrite(p_ds_grid,
+            //                                 p_e_grid,
+            //                                 static_cast<void*>(p_shared),
+            //                                 M,
+            //                                 N,
+            //                                 stride_ds,
+            //                                 stride_e,
+            //                                 cde_element_op,
+            //                                 b2c_tile_map);
         }
         else if(work_scheduler.HasTile())
         {
+            {
+                // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
+                const uint32_t flag_v2 =
+                    work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset);
+                if(is_thread_local_1d_id_idx(0))
+                {
+                    printf("bid: %d, group: %d, Waiting for Reduction flag_v[%u]: %u \n",
+                           static_cast<index_t>(blockIdx.x),
+                           group_id,
+                           work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
+                           // static_cast<index_t>(blockIdx.x));
+                           flag_v2);
+                }
+            }
             work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
         }
     } while(work_scheduler.HasTile());
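Setting the debug printfs aside, the surviving control flow in this kernel is a flag-based split-K reduction: every workgroup that finishes its share of K-tiles for an [M,N] output tile calls FlagFinished; the workgroup owning the first K-split waits for its neighbours, reduces the partial results from the workspace, resets the flag and writes the output, while the remaining workgroups wait for that reset before reusing their workspace. The following is a self-contained host-side analogue of that hand-shake using std::thread and std::atomic; it illustrates the synchronisation pattern only and is not the device code.

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main()
    {
        constexpr unsigned k_batch = 4;            // K-splits of one output tile
        std::vector<float> workspace(k_batch, 0.0f); // one partial result per K-split
        std::atomic<unsigned> flag{0};               // stands in for the workgroup flag
        float result = 0.0f;

        auto workgroup = [&](unsigned k_idx) {
            workspace[k_idx] = 1.0f * (k_idx + 1);        // "RunGEMM": produce a partial
            flag.fetch_add(1, std::memory_order_release); // FlagFinished

            if(k_idx == 0) // IsFirstKSplitBlock
            {
                while(flag.load(std::memory_order_acquire) < k_batch) {} // WaitForNeighbours
                for(float p : workspace)                                  // AccumulatePartials
                    result += p;
                flag.store(0, std::memory_order_release);                 // Reset
                // RunWrite would store `result` to the output tile here.
            }
            else
            {
                while(flag.load(std::memory_order_acquire) != 0) {} // WaitForReduction
            }
        };

        std::vector<std::thread> pool;
        for(unsigned k = 0; k < k_batch; ++k)
            pool.emplace_back(workgroup, k);
        for(auto& t : pool)
            t.join();

        std::printf("reduced result: %f\n", result); // 1 + 2 + 3 + 4 = 10
    }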
@@ -751,7 +839,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
                             Block2ETileMapKSplit::GetAccWorkspaceSize(
                                 sizeof(typename GridwiseGemm::AccType), grid_size);

-            std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+            // std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+            std::size_t flag_count = arg.tile_count_ / arg.K_BATCH;

             if(stream_config.log_level_ > 0)
             {

@@ -987,7 +1076,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 arg.gpu_cu_count_ * std::min(arg.occupancy_num_blocks_, KernelConfig::CU_BLOCKS);
             int grid_size       = std::min(arg.tile_count_, occ_grid_size);
             int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
-            int flag_count      = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+
+            if(arg.tile_count_ > occ_grid_size && grid_size * tiles_per_block > arg.tile_count_)
+            {
+                grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
+            }
+
+            // int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+            int flag_count = arg.tile_count_ / arg.K_BATCH;

             // This would be the maximum needed workspace size. Since actual grid size, which determines
             // the amount of workspace bytes needed, may be less due to the number of available CUs in
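A worked example of the sizing logic above, with illustrative numbers (only the formulas come from the diff):

    tile_count_ = 100, occ_grid_size = 30, K_BATCH = 4

    grid_size       = min(100, 30)           = 30
    tiles_per_block = (100 + 30 - 1) / 30    = 4
    30 * 4 = 120 > 100, so the new clause shrinks grid_size to (100 + 4 - 1) / 4 = 25

    old flag_count  = (30 * 4 + 4 - 1) / 4   = 30
    new flag_count  = 100 / 4                = 25

In other words, the new flag_count allocates one flag per MN output tile (the test changed in this commit describes flag_count the same way), rather than deriving it from the padded grid_size * tiles_per_block product.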
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp

@@ -106,6 +106,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

+    template <index_t... Ids>
+    __device__ static bool is_thread_local_1d_id_idx()
+    {
+        const auto tid = get_thread_local_1d_id();
+        return ((tid == Ids) || ...);
+    }
+
     public:
     using AccType = AccDataType;

@@ -906,6 +913,32 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                        Sequence<6>{},
                        Sequence<7>{}));

+        // if (is_thread_local_1d_id_idx<0>())
+        // {
+        //     // printf("bid: %d; tid: %d; [Store Partials] c_block_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
+        //     //        static_cast<index_t>(blockIdx.x),
+        //     //        static_cast<index_t>(threadIdx.x),
+        //     //        M0.value,
+        //     //        N0.value,
+        //     //        M1.value,
+        //     //        N1.value,
+        //     //        M2.value,
+        //     //        M3.value,
+        //     //        M4.value,
+        //     //        N2.value);
+        //     printf("bid: %d; tid: %d; [Store Partials] wrkspace_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7).value);
+        // }
+
         auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
         auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

@@ -963,11 +996,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                      n_thread_data_on_block_idx[I2]),
                           ck::tensor_operation::element_wise::PassThrough{}};

+        // if (is_thread_local_1d_id_idx<0, 64, 223>())
+        // {
+        //     printf("[StorePartials] bid: %d, tid: %d: dst origin idx[%d, %d, %d, %d, %d, %d, %d, %d]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            (static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
+        //            n_thread_data_on_block_idx[I0],
+        //            m_thread_data_on_block_idx[I1],
+        //            n_thread_data_on_block_idx[I1],
+        //            m_thread_data_on_block_idx[I2],
+        //            m_thread_data_on_block_idx[I3],
+        //            m_thread_data_on_block_idx[I4],
+        //            n_thread_data_on_block_idx[I2]);
+        // }
+
         c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                        c_thread_buf,
                                        workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        w_grid_buf);
+
+        if(is_thread_local_1d_id_idx<0>())
+        {
+            printf("[StorePartials] done. bid: %d, tid: %d \n",
+                   static_cast<index_t>(blockIdx.x),
+                   static_cast<index_t>(threadIdx.x));
+        }
     }

     __device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
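The is_thread_local_1d_id_idx helper added above relies on a C++17 fold expression over its non-type template parameters, so is_thread_local_1d_id_idx<0, 64, 223>() is true only on threads 0, 64 and 223. A small stand-alone illustration of the same pattern, with an explicit id argument in place of the device-side get_thread_local_1d_id() (the names below are hypothetical):

    #include <cstdio>

    // Same fold-expression trick, written against an explicit id.
    template <int... Ids>
    constexpr bool is_one_of(int tid)
    {
        return ((tid == Ids) || ...); // true if tid matches any of Ids...
    }

    int main()
    {
        static_assert(is_one_of<0, 64, 223>(64));
        static_assert(!is_one_of<0, 64, 223>(1));
        std::printf("thread 223 selected: %d\n", is_one_of<0, 64, 223>(223));
    }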
include/ck/utility/work_scheduling.hpp

@@ -51,7 +51,7 @@ class StridedReductionTileLoop
     {
         tile_id_++;
         block_tile_idx_++;

-        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
+        return HasTile();
     }

     __device__ index_t GetFlagCount(index_t k_tiles) const

@@ -75,11 +75,12 @@ class StridedReductionTileLoop
     ///
     /// @return     The workgroup flag index.
     ///
-    __device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
+    __device__ uint32_t GetWorkgroupFlagIdx([[maybe_unused]] index_t k_tiles,
                                             index_t output_tile_idx,
                                             index_t output_tile_idx_offset) const
     {
-        return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
+        // return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
+        return output_tile_idx + output_tile_idx_offset;
     }

     ///

@@ -92,7 +93,7 @@ class StridedReductionTileLoop
     __device__ void
     FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
     {
-        const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
+        /* [[maybe_unused]] */ const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
         finished_block_flags_.inc(fidx);
     }

@@ -111,8 +112,10 @@ class StridedReductionTileLoop
         // We use < because for some cases we may have +1 more workgroups per dim.
         // Ie when k_tiles = 5, tiles_per_block = 3.
         finished_block_flags_.wait_lt(
             GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
             workgroups_per_dim);
+        // [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
     }

     ///

@@ -128,6 +131,8 @@ class StridedReductionTileLoop
         // Wait untill the counter has been reset.
         finished_block_flags_.wait_eq(
             GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
+        // [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
     }

     ///

@@ -141,6 +146,8 @@ class StridedReductionTileLoop
     {
         finished_block_flags_.reset(
             GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
+        // [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
     }

     ///
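The "k_tiles = 5, tiles_per_block = 3" comment in WaitForNeighbours is worth unpacking with the same numbers; the figures below only illustrate that comment and the GetWorkgroupFlagIdx change above:

    k_tiles = 5, tiles_per_block = 3
    -> one output tile's K range is covered by 2 workgroups
       (one accumulates 3 K-tiles, the other the remaining 2),
       so the finished-counter for that tile reaches 2 rather than a fixed quotient,
       which is why the waiter uses wait_lt(..., workgroups_per_dim) and the reducer
       later reads the actual counter with GetFlagValue().

    flag index, before: (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles)
    flag index, after:   output_tile_idx + output_tile_idx_offset

Dropping the modulo means each MN output tile now owns a distinct flag slot, which matches the flag_count = tile_count / k_batch sizing introduced elsewhere in this commit.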
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp

@@ -16,96 +16,96 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Row, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Col, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Row, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Col, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

 void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
 ...
@@ -120,31 +120,31 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
                                                   PassThrough,
                                                   PassThrough>>>& instances);

-void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F8,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F8, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F8,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F8, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

 template <typename ALayout,
           typename BLayout,
 ...
@@ -186,48 +186,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<ELayout, Row>)
             {
-                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
-                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                // add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                 add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
             }
-            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
-                              is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-                    op_ptrs);
-            }
-            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-                              is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
-            }
-            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-                              is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
-            }
-        }
-        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
-                          is_same_v<EDataType, half_t>)
-        {
-            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
-            }
-        }
-        else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
-                          is_same_v<EDataType, half_t>)
-        {
-            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
-            }
-        }
+            // add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+            //     op_ptrs);
+            // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+            //                   is_same_v<ELayout, Row>)
+            // {
+            //     add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            // }
+            // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+            //                   is_same_v<ELayout, Row>)
+            // {
+            //     add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+            // }
+        }
+        // else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
+        //                   is_same_v<EDataType, half_t>)
+        // {
+        //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+        //                  is_same_v<ELayout, Row>)
+        //     {
+        //         add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
+        //     }
+        // }
+        // else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
+        //                   is_same_v<EDataType, half_t>)
+        // {
+        //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+        //                  is_same_v<ELayout, Row>)
+        //     {
+        //         add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
+        //     }
+        // }

         return op_ptrs;
     }
 };
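For context, the add_device_* lists declared in this header are normally consumed through the factory's GetInstances() call, and the profiler change later in this commit iterates exactly such a list of op_ptrs. Below is a hedged sketch of that usage for the f16 x f16 row-major case that stays enabled here; the fully-qualified type spellings and the factory call are my assumptions about the library's usual entry points, not something this diff shows verbatim.

    // Sketch only (not part of the commit): query the factory for grouped-GEMM
    // instances and print their names, mirroring what the profiler loop does.
    #include <iostream>
    #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"

    int main()
    {
        using Row         = ck::tensor_layout::gemm::RowMajor;
        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
        using DeviceOp    = ck::tensor_operation::device::DeviceGroupedGemm<
            Row, Row, ck::Tuple<>, Row,                        // layouts: A, B, Ds, E
            ck::half_t, ck::half_t, ck::Tuple<>, ck::half_t,   // data types: A, B, Ds, E
            PassThrough, PassThrough, PassThrough>;            // elementwise ops

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        for(const auto& op : op_ptrs)
            std::cout << op->GetTypeString() << std::endl;
    }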
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp

@@ -17,18 +17,18 @@ namespace device {
 namespace instance {

 // MultiD version
-void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
-        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+// void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
+//         Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);

 void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
 ...

@@ -93,8 +93,8 @@ struct DeviceOperationInstanceFactory<
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                           is_same_v<ELayout, Row>)
         {
-            add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-                op_ptrs);
+            // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+            //     op_ptrs);
         }
     }

     return op_ptrs;
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt

 add_instance_library(
    device_grouped_gemm_instance
-   device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
-   device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
-   device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
-   device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
-   device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
-   device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
+   # device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+   # device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
+   # device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
+   # device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
+   # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
+   # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
-   device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
-   device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
-   device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
+   # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+   # device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
+   # device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/CMakeLists.txt

 add_instance_library(
    device_grouped_gemm_multiple_d_instance
-   device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+   # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp  (+60, -60)

This diff is collapsed.
profiler/include/profiler/profile_grouped_gemm_multiple_d_splitk_impl.hpp

@@ -219,6 +219,8 @@ bool profile_ggemm_multid_splitk(int do_verification,
     // profile device GEMM instances
     for(auto& gemm_ptr : op_ptrs)
     {
+        std::cout << "Running instance: " << gemm_ptr->GetTypeString() << std::endl;
+
         auto gptr = dynamic_cast<DeviceOp*>(gemm_ptr.get());

         auto argument_ptr = gemm_ptr->MakeArgumentPointer(

@@ -247,20 +249,24 @@ bool profile_ggemm_multid_splitk(int do_verification,
         for(std::size_t j = 0; j < kbatch_list.size(); j++)
         {
             auto kbatch_curr = kbatch_list[j];
+            // std::cout << ">>> kbatch: " << kbatch_curr << std::endl;

             gptr->SetKBatchSize(argument_ptr.get(), kbatch_curr);
             DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
             gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
+            // std::cout << "WorkspacePointer set!" << std::endl;

             if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
             {
                 for(std::size_t i = 0; i < gemm_descs.size(); i++)
                     c_device_buf[i]->SetZero();

                 invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+                // invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 1});
+                // std::cout << ">>>>>GPU Run end!" << std::endl;

                 if(do_verification)
                 {

@@ -304,12 +310,16 @@ bool profile_ggemm_multid_splitk(int do_verification,
                               << (instance_pass ? "SUCCEED" : "FAILED") << std::endl;

                     pass = pass && instance_pass;
+                    std::cout << ">>>>>CPU verification end!" << std::endl;
                 }

                 if(time_kernel)
                 {
-                    float avg_time =
-                        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+                    std::cout << ">>>>>GPU time profiling start!" << std::endl;
+                    float avg_time = invoker_ptr->Run(
+                        // argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 5, 30});
+                        argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 0, 1});

                     std::size_t flop = 0, num_btype = 0;
                     for(std::size_t i = 0; i < gemm_descs.size(); i++)
                     {

@@ -335,6 +345,7 @@ bool profile_ggemm_multid_splitk(int do_verification,
                         best_gb_per_sec = gb_per_sec;
                         best_kbatch     = kbatch_curr;
                     }
+                    // std::cout << ">>>>>GPU time profiling end!" << std::endl;
                 }
             }
             else
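The StreamConfig initialisers touched above — {nullptr, false}, {nullptr, time_kernel, 1, 0, 1} and the commented {nullptr, time_kernel, 1, 5, 30} — together with the kernel_launch.hpp change earlier in this commit suggest the aggregate's fields are, in order, the stream handle, the time_kernel switch, log_level_, cold_niters_ and nrepeat_; the debug run therefore logs verbosely and times a single repetition with no warm-up. A sketch of that inferred shape follows; the field order and defaults are assumptions, not copied from the headers.

    // Inferred shape of the config aggregate used above; the real definition
    // lives in the CK headers and may differ in defaults.
    struct StreamConfigSketch
    {
        void* stream_id_   = nullptr; // HIP stream handle in the real struct
        bool  time_kernel_ = false;   // measure and report kernel time
        int   log_level_   = 0;       // extra logging when > 0
        int   cold_niters_ = 5;       // warm-up launches (see kernel_launch.hpp change)
        int   nrepeat_     = 50;      // timed launches to average over
    };

    // {nullptr, time_kernel, 1, 0, 1}  -> log level 1, no warm-up, single timed run
    // {nullptr, time_kernel, 1, 5, 30} -> log level 1, 5 warm-ups, 30 timed runs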
profiler/src/CMakeLists.txt

 # ckProfiler
 set(PROFILER_SOURCES
     profiler.cpp
-    profile_gemm.cpp
-    profile_gemm_splitk.cpp
-    profile_gemm_bias_add_reduce.cpp
-    profile_gemm_add_multiply.cpp
-    profile_gemm_multiply_add.cpp
-    profile_gemm_reduce.cpp
-    profile_batched_gemm.cpp
-    profile_batched_gemm_reduce.cpp
-    profile_conv_fwd.cpp
-    profile_conv_fwd_bias_relu.cpp
-    profile_conv_fwd_bias_relu_add.cpp
-    profile_conv_bwd_data.cpp
-    profile_grouped_conv_fwd.cpp
-    profile_grouped_conv_bwd_weight.cpp
-    profile_reduce.cpp
-    profile_groupnorm_bwd_data.cpp
-    profile_groupnorm_fwd.cpp
-    profile_layernorm_bwd_data.cpp
-    profile_layernorm_fwd.cpp
-    profile_max_pool3d_fwd.cpp
-    profile_avg_pool3d_bwd.cpp
-    profile_max_pool3d_bwd.cpp
-    profile_softmax.cpp
-    profile_batchnorm_fwd.cpp
-    profile_batchnorm_bwd.cpp
-    profile_batchnorm_infer.cpp
-    profile_grouped_conv_bwd_data.cpp
-    profile_conv_tensor_rearrange.cpp
+    # profile_gemm.cpp
+    # profile_gemm_splitk.cpp
+    # profile_gemm_bias_add_reduce.cpp
+    # profile_gemm_add_multiply.cpp
+    # profile_gemm_multiply_add.cpp
+    # profile_gemm_reduce.cpp
+    # profile_batched_gemm.cpp
+    # profile_batched_gemm_reduce.cpp
+    # profile_conv_fwd.cpp
+    # profile_conv_fwd_bias_relu.cpp
+    # profile_conv_fwd_bias_relu_add.cpp
+    # profile_conv_bwd_data.cpp
+    # profile_grouped_conv_fwd.cpp
+    # profile_grouped_conv_bwd_weight.cpp
+    # profile_reduce.cpp
+    # profile_groupnorm_bwd_data.cpp
+    # profile_groupnorm_fwd.cpp
+    # profile_layernorm_bwd_data.cpp
+    # profile_layernorm_fwd.cpp
+    # profile_max_pool3d_fwd.cpp
+    # profile_avg_pool3d_bwd.cpp
+    # profile_max_pool3d_bwd.cpp
+    # profile_softmax.cpp
+    # profile_batchnorm_fwd.cpp
+    # profile_batchnorm_bwd.cpp
+    # profile_batchnorm_infer.cpp
+    # profile_grouped_conv_bwd_data.cpp
+    # profile_conv_tensor_rearrange.cpp
 )

 if(DL_KERNELS)
 ...

@@ -36,21 +36,22 @@ if(DL_KERNELS)
 endif()

 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
+    # list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
+    # list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
     list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
-    list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
+    list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiple_d_splitk.cpp)
 endif()

 if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
-    list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
+    # list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
+    # list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
 endif()

 set(PROFILER_EXECUTABLE ckProfiler)
 ...

@@ -59,42 +60,42 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
 target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)

 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)

 if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
 endif()
 ...

@@ -104,16 +105,17 @@ if(DL_KERNELS)
 endif()

 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
+    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_multiple_d_instance)
 endif()

 rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
test/work_scheduling/test_strided_reduction_tile_loop.cpp

@@ -154,9 +154,9 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         // Accumulate partial results. We can have different # of workgroups to reduce, thus we
         // read actual flag value.
-        const index_t flag_v = __builtin_amdgcn_readfirstlane(
+        const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
             work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));

-        for(index_t i = 1; i < flag_v; ++i)
+        for(uint32_t i = 1; i < flag_v; ++i)
         {
             partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                           i * MPerBlock * NPerBlock + get_thread_local_1d_id()];

@@ -174,7 +174,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
             C_thread_tile_n_idx] = partial_result;
     }
-    else
+    else if(work_scheduler.HasTile())
     {
         work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
     }

@@ -284,10 +284,11 @@ struct GroupedGemmStridedTileLoopReduce
         DeviceMem gemm_workspace, gemm_flags;

-        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
+        // const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
         // This is the number of MN-output tiles which we cover with workgroups.
         // We launch k_batch / tiles_per_block workgroups for each output tile.
-        const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
+        // const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
+        const index_t flag_count = tile_count / k_batch;

         gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
         gemm_flags.Realloc(flag_count * sizeof(uint32_t));