Commit bf49603b authored Oct 02, 2023 by Adam Osewski
Strided Reduction Tile Loop work scheduler.
parent 983bd0a4
Showing 4 changed files with 367 additions and 0 deletions:
include/ck/utility/work_scheduling.hpp                      +104  -0
test/CMakeLists.txt                                           +1  -0
test/work_scheduling/CMakeLists.txt                           +6  -0
test/work_scheduling/test_strided_reduction_tile_loop.cpp   +256  -0
include/ck/utility/work_scheduling.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

enum struct WorkSchedulingPolicy
{
    StridedTileLoop
};
///
/// @brief This class describes a strided reduction tile loop work scheduling policy.
///
/// @par Overview
///      This work scheduling policy assumes a linear (strided) mapping of workgroups along
///      the reduced dimension. In a GEMM problem this means that consecutive workgroups are
///      mapped to strided data tiles along the K dimension. Such a mapping can be obtained
///      using e.g. @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///      All workgroups aligned along a particular reduced dimension have to reduce their
///      partial results. To do that, global flags and atomics are used to communicate
///      between those workgroups.
///
class StridedReductionTileLoop
{
    public:
    __device__ StridedReductionTileLoop(index_t tile_count,
                                        uint32_t* const __restrict__ p_flag_count)
        : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
          finished_block_flags_{p_flag_count}
    {
    }

    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }
    ///
    /// @brief Calculate this workgroup's flag index.
    ///
    /// @note This scheduler intentionally does not store the flag index as a member, since
    ///       the number of `dim_tiles` may change while iterating (e.g. in grouped gemm,
    ///       different groups may have a different number of `dim_tiles` in the K dimension).
    ///
    /// @param[in] dim_tiles        The number of data tiles in the reduced dimension.
    /// @param[in] output_tile_idx  The output (MN) tile index.
    ///
    /// @return The workgroup flag index.
    ///
    __device__ index_t GetWorkgroupFlagIdx(index_t dim_tiles, index_t output_tile_idx) const
    {
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch dim_tiles (k_batch) / tiles_per_block workgroups for each output tile.
        const index_t flag_count =
            (get_grid_size() * tiles_per_block_ + dim_tiles - 1) / dim_tiles;
        return output_tile_idx % flag_count;
    }
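
    // Worked example (illustrative numbers, not taken from this commit): with
    // get_grid_size() == 16, tiles_per_block_ == 4 and dim_tiles == 16 (k_batch),
    // flag_count == (16 * 4 + 16 - 1) / 16 == 4, i.e. the grid covers 4 output tiles
    // concurrently and output_tile_idx == 5 maps to flag slot 5 % 4 == 1.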
    ///
    /// @brief Mark this workgroup as having finished its work.
    ///
    /// @param[in] dim_tiles        The number of tiles in the reduced dimension.
    /// @param[in] output_tile_idx  The output (MN) tile index.
    ///
    __device__ void FlagFinished(index_t dim_tiles, index_t output_tile_idx)
    {
        finished_block_flags_.inc(GetWorkgroupFlagIdx(dim_tiles, output_tile_idx));
    }
    ///
    /// @brief Wait until each workgroup has finished its work.
    ///
    /// @param[in] dim_tiles        The number of tiles in the reduced dimension.
    /// @param[in] output_tile_idx  The output (MN) tile index.
    ///
    __device__ void WaitForNeighbours(index_t dim_tiles, index_t output_tile_idx)
    {
        // Wait until all workgroups finish, then reset the counter.
        const index_t workgroups_per_dim = (dim_tiles + tiles_per_block_ - 1) / tiles_per_block_;
        finished_block_flags_.wait_set(
            GetWorkgroupFlagIdx(dim_tiles, output_tile_idx), workgroups_per_dim, 0);
    }
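
    // Protocol sketch: every workgroup sharing an output tile calls FlagFinished() once its
    // K tiles are done; the designated first K-split block then waits in WaitForNeighbours()
    // until the flag reaches workgroups_per_dim, and wait_set() resets the flag to 0 so the
    // slot can be reused for the next output tile.
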
    const index_t tile_count_;
    const index_t tiles_per_block_;
    index_t tile_id_;
    index_t block_tile_idx_;
    workgroup_barrier finished_block_flags_;
};

} // namespace ck
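
For reference, a condensed sketch of the kernel-side pattern this scheduler is designed for,
adapted from the test added below; `b2c_tile_map`, `p_flags`, `tile_count`, `k_batch` and the
workspace handling are placeholders from that test, not part of this header:

    StridedReductionTileLoop work_scheduler{tile_count, p_flags};
    do
    {
        // ... accumulate this block's partial result for the current K data tile ...
    } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

    if(!b2c_tile_map.IsFirstKSplitBlock(work_scheduler.tiles_per_block_))
    {
        // publish the partial result to this block's workspace slot
    }
    work_scheduler.FlagFinished(k_batch, b2c_tile_map.GetOutputTileIdx());

    if(b2c_tile_map.IsFirstKSplitBlock(work_scheduler.tiles_per_block_))
    {
        work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetOutputTileIdx());
        // ... reduce the neighbours' partial results and store the final output tile ...
    }
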
test/CMakeLists.txt

@@ -156,6 +156,7 @@ add_subdirectory(pool)
 add_subdirectory(batched_gemm_multi_d)
 add_subdirectory(grouped_convnd_bwd_data)
 add_subdirectory(conv_tensor_rearrange)
+add_subdirectory(work_scheduling)
 if(GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
test/work_scheduling/CMakeLists.txt (new file, mode 100644)

add_custom_target(test_work_scheduling)

add_gtest_executable(test_strided_reduction_tile_loop test_strided_reduction_tile_loop.cpp)
target_link_libraries(test_strided_reduction_tile_loop PRIVATE utility)
add_dependencies(test_work_scheduling test_strided_reduction_tile_loop)
test/work_scheduling/test_strided_reduction_tile_loop.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <gtest/gtest.h>

#include <ck/ck.hpp>
#include <ck/host_utility/kernel_launch.hpp>
#include <ck/utility/common_header.hpp>
#include <ck/utility/work_scheduling.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

using namespace ck;
namespace {

template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
__global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
                                                    index_t N,
                                                    index_t K,
                                                    const float* p_A,
                                                    const float* p_B,
                                                    float* p_C,
                                                    float* p_workspace,
                                                    uint32_t* p_flags,
                                                    index_t tile_count,
                                                    index_t k_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    StridedReductionTileLoop work_scheduler{tile_count, p_flags};

    const auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
    BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(c_grid_desc_m_n, k_batch);

    float partial_result = 0.f;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // Assume MK-KN-MN data layout.
    const index_t stride_a = K;
    const index_t stride_b = N;
    const index_t stride_c = N;

    // K is the contiguous dim in memory, as well as fastest changing dim in B2C mapping.
    const auto block_work_idx = b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_);
    const index_t block_m_id  = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
    const index_t block_n_id  = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
    do
    {
        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(b2c_tile_map.GetTileKIdx());

        const index_t A_m_tile_offset     = block_m_id * MPerBlock;
        const index_t A_k_tile_offset     = k_batch_id * KPerBlock;
        const index_t A_thread_tile_m_idx = get_thread_local_1d_id() / NPerBlock;

        const index_t B_n_tile_offset     = block_n_id * NPerBlock;
        const index_t B_k_tile_offset     = k_batch_id * KPerBlock;
        const index_t B_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;

        for(index_t k = 0; k < KPerBlock; ++k)
        {
            partial_result +=
                p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k] *
                p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset + B_thread_tile_n_idx];
        }
    } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
    // If this block is not the first K-split block of its [M,N] output tile,
    // store its partial result to the workspace.
    if(!b2c_tile_map.IsFirstKSplitBlock(work_scheduler.tiles_per_block_))
    {
        // Assume we have MPerBlock x NPerBlock tile per each workgroup in contiguous memory.
        p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
            partial_result;
    }

    work_scheduler.FlagFinished(k_batch, b2c_tile_map.GetOutputTileIdx());
    // The workgroup which processed the first K tile accumulates the results and stores them
    // to GMEM.
    if(b2c_tile_map.IsFirstKSplitBlock(work_scheduler.tiles_per_block_))
    {
        // Wait until all other blocks for this [M,N] tile store their results.
        work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetOutputTileIdx());

        // Accumulate partial results.
        const index_t workgroups_per_dim =
            (k_batch + work_scheduler.tiles_per_block_ - 1) / work_scheduler.tiles_per_block_;

        for(index_t i = 0; i < workgroups_per_dim; ++i)
        {
            partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                          i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
        }

        // Write the result.
        const index_t C_m_tile_offset     = block_m_id * MPerBlock;
        const index_t C_thread_tile_m_idx = get_thread_local_1d_id() / NPerBlock;
        const index_t C_n_tile_offset     = block_n_id * NPerBlock;
        const index_t C_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;

        p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
            C_thread_tile_n_idx] = partial_result;
    }
#else
    // Unused on architectures this kernel is not compiled for.
    ignore = M;
    ignore = N;
    ignore = K;
    ignore = p_A;
    ignore = p_B;
    ignore = p_C;
    ignore = p_workspace;
    ignore = p_flags;
    ignore = tile_count;
    ignore = k_batch;
#endif
}

} // anonymous namespace
template <index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
struct GemmStridedTileLoopReduce
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using AElementOp = PassThrough;
    using BElementOp = PassThrough;
    using CElementOp = PassThrough;

    using ADataType   = float;
    using BDataType   = float;
    using CDataType   = float;
    using AccDataType = float;

    constexpr static auto DeviceGemmKernel =
        gemm_naive_strided_tile_loop_reduce<MPerBlock, NPerBlock, KPerBlock>;

    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                            BDataType,
                                                                            CDataType,
                                                                            AccDataType,
                                                                            AElementOp,
                                                                            BElementOp,
                                                                            CElementOp>;

    GemmStridedTileLoopReduce() = default;
    bool Run(index_t M, index_t N, index_t K, index_t k_batch)
    {
        Tensor<float> a_m_k(HostTensorDescriptor({M, K}, {K, 1}));
        Tensor<float> b_k_n(HostTensorDescriptor({K, N}, {N, 1}));

        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);

        Tensor<float> c_m_n_host(HostTensorDescriptor({M, N}, {N, 1}));
        Tensor<float> c_m_n_device(HostTensorDescriptor({M, N}, {N, 1}));

        DeviceMem a_m_k_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
        DeviceMem b_k_n_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
        DeviceMem c_m_n_device_buf(sizeof(float) * c_m_n_device.mDesc.GetElementSpaceSize());

        a_m_k_device_buf.ToDevice(a_m_k.mData.data());
        b_k_n_device_buf.ToDevice(b_k_n.mData.data());
        c_m_n_device_buf.SetZero();
        c_m_n_host.SetZero();

        DeviceMem gemm_workspace, gemm_flags;
        BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(M, N, k_batch);

        const index_t tile_count      = b2c_tile_map.CalculateGridSize(M, N);
        const index_t grid_size       = tile_count / 4;
        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;

        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_batch / tiles_per_block workgroups for each output tile.
        const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;

        gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
        gemm_flags.Realloc(flag_count * sizeof(uint32_t));

        gemm_workspace.SetZero();
        gemm_flags.SetZero();
        launch_and_time_kernel(
            StreamConfig{nullptr, false},
            DeviceGemmKernel,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            M,
            N,
            K,
            reinterpret_cast<const float*>(a_m_k_device_buf.GetDeviceBuffer()),
            reinterpret_cast<const float*>(b_k_n_device_buf.GetDeviceBuffer()),
            reinterpret_cast<float*>(c_m_n_device_buf.GetDeviceBuffer()),
            reinterpret_cast<float*>(gemm_workspace.GetDeviceBuffer()),
            reinterpret_cast<uint32_t*>(gemm_flags.GetDeviceBuffer()),
            tile_count,
            k_batch);
        auto a_element_op = AElementOp{};
        auto b_element_op = BElementOp{};
        auto c_element_op = CElementOp{};

        auto ref_gemm     = ReferenceGemmInstance{};
        auto ref_invoker  = ref_gemm.MakeInvoker();
        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);

        c_m_n_device_buf.FromDevice(c_m_n_device.mData.data());

        return ck::utils::check_err(c_m_n_device, c_m_n_host);
    }
};
TEST(TestStridedReductionTileLoop, SingleDataTile)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;
    const index_t kbatch        = 4;

    EXPECT_TRUE((GemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        MPerBlock, NPerBlock, KPerBlock * kbatch, kbatch)));
}

TEST(TestStridedReductionTileLoop, SingleOutputMultipleDataTiles)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;
    const index_t kbatch        = 16;

    EXPECT_TRUE((GemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        MPerBlock, NPerBlock, KPerBlock * kbatch, kbatch)));
}

TEST(TestStridedReductionTileLoop, MultipleDataTiles)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;
    const index_t kbatch        = 16;

    EXPECT_TRUE((GemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        MPerBlock * 4, NPerBlock * 4, KPerBlock * kbatch, kbatch)));
}