gaoqiong / composable_kernel_ROCM · Commits · 8f571c0b

Commit 8f571c0b, authored Apr 24, 2024 by Harisankar Sadasivan
Parent: 5490b99c

    files modified for gemm with atomics and reduction
Showing 3 changed files with 109 additions and 42 deletions:

  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp   +40 −14
  include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp               +29 −10
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp     +40 −18
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp (file mode 100644 → 100755)
@@ -78,8 +78,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         BlockToCTileMap_GemmStreamK<MPerBlock,
                                     NPerBlock,
                                     K0PerBlock * K1,
-                                    StreamKReductionStrategy::Atomic>,
-                                    // StreamKReductionStrategy::Reduction>,
+                                    // StreamKReductionStrategy::Atomic>,
+                                    StreamKReductionStrategy::Reduction>,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
@@ -149,10 +149,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
             hipDevice_t dev;
             hip_check_error(hipGetDevice(&dev));
             hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
             num_cu = dev_prop.multiProcessorCount;

-            dim3 grid_dims = karg.block_mapping.sk_num_blocks;
-            printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
-                   num_cu * occupancy);
+            dim3 grid_dims =
+                (karg.block_mapping.sk_num_blocks ? karg.block_mapping.sk_num_blocks
+                                                  : karg.block_mapping.reduction_start_block_idx);

             float ave_time = 0;

             // TODO: remove clear buffer for streamk kernels
@@ -187,11 +187,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                 karg.block_mapping.get_workspace_size_for_acc(
                     sizeof(typename GridwiseGemm::FloatAcc));

             auto preprocess = [&]() {
                 hipGetErrorString(
                     hipMemsetAsync(workspace_semaphore,
                                    0,
-                                   sizeof(num_cu),
+                                   karg.block_mapping.get_workspace_size_for_semaphore(),
                                    stream_config.stream_id_));
             };

             ave_time = launch_and_time_kernel_with_preprocess(stream_config,
@@ -282,8 +279,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                              CElementwiseOperation,
                              uint32_t NumSKBlocks = 0)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, NumSKBlocks};
+        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
+        int occupancy, num_cu;
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+        hipDeviceProp_t dev_prop;
+        hipDevice_t dev;
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+        num_cu = dev_prop.multiProcessorCount;
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        static_cast<uint32_t>(num_cu),
+                        static_cast<uint32_t>(occupancy),
+                        NumSKBlocks};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -303,7 +319,15 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                              CElementwiseOperation,
                              index_t NumSKBlocks = 0) override
     {
+        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
+        int occupancy, num_cu;
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+        hipDeviceProp_t dev_prop;
+        hipDevice_t dev;
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+        num_cu = dev_prop.multiProcessorCount;
         return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                           reinterpret_cast<const BDataType*>(p_b),
                                           reinterpret_cast<CDataType*>(p_c),
@@ -313,6 +337,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                                           StrideA,
                                           StrideB,
                                           StrideC,
+                                          static_cast<uint32_t>(num_cu),
+                                          static_cast<uint32_t>(occupancy),
                                           static_cast<uint32_t>(NumSKBlocks));
     }
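The MakeArgument/MakeArgumentPointer changes above size the stream-k launch from the device itself: the kernel's occupancy is queried with hipOccupancyMaxActiveBlocksPerMultiprocessor, the compute-unit count comes from hipGetDeviceProperties, and num_cu * occupancy is the "one full wave" figure the block mapping later uses as the recommended number of stream-k blocks. The stand-alone sketch below shows only that host-side query; dummy_kernel, BlockSize and the surrounding main are illustrative placeholders, not composable_kernel code.

    // occupancy_probe.cpp — build with: hipcc occupancy_probe.cpp -o occupancy_probe
    #include <hip/hip_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    #define HIP_CHECK(cmd)                                                      \
        do {                                                                    \
            hipError_t e_ = (cmd);                                              \
            if(e_ != hipSuccess) {                                              \
                std::fprintf(stderr, "HIP error: %s (line %d)\n",               \
                             hipGetErrorString(e_), __LINE__);                  \
                std::exit(EXIT_FAILURE);                                        \
            }                                                                   \
        } while(0)

    __global__ void dummy_kernel() {} // stand-in for kernel_gemm_xdlops_streamk<GridwiseGemm>

    int main()
    {
        constexpr int BlockSize = 256; // in the device op this comes from the template arguments

        // How many workgroups of this kernel can be resident per compute unit.
        int occupancy = 0;
        HIP_CHECK(hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &occupancy, dummy_kernel, BlockSize, 0 /* dynamic LDS bytes */));

        // How many compute units the current device has.
        int dev = 0;
        hipDeviceProp_t dev_prop;
        HIP_CHECK(hipGetDevice(&dev));
        HIP_CHECK(hipGetDeviceProperties(&dev_prop, dev));
        const int num_cu = dev_prop.multiProcessorCount;

        // "One full wave" of persistent workgroups — the stream-k block count used above.
        std::printf("CUs: %d, occupancy: %d, one full wave: %d blocks\n",
                    num_cu, occupancy, num_cu * occupancy);
        return 0;
    }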
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (file mode 100644 → 100755)
@@ -1000,7 +1000,7 @@ struct BlockToCTileMap_GemmStreamK
     //--------------------------------------
     // pass to device
-    uint32_t sk_num_blocks;
+    mutable uint32_t sk_num_blocks;
     uint32_t sk_num_big_blocks;
     uint32_t dp_start_block_idx;
     uint32_t reduction_start_block_idx;
@@ -1011,7 +1011,12 @@ struct BlockToCTileMap_GemmStreamK
     MDiv equiv_tiles_little; // for reduction

     // prefer construct on host
-    BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t sk_blocks = 0)
+    BlockToCTileMap_GemmStreamK(uint32_t m,
+                                uint32_t n,
+                                uint32_t k,
+                                uint32_t num_cu,
+                                uint32_t occupancy,
+                                uint32_t sk_blocks = 0)
     {
         // total output tiles
         uint32_t num_tiles =
@@ -1019,10 +1024,24 @@ struct BlockToCTileMap_GemmStreamK
         k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

         uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
-        // default to regular DP GEMM if sk blocks == 0
-        sk_num_blocks = sk_blocks;
+        const uint32_t one_wave = num_cu * occupancy;
+        if((sk_blocks > one_wave) && (num_tiles > one_wave))
+        {
+            printf("WARNING: Do not tune above max possible occupancy for the kernel, "
+                   "defaulting to max occupancy\n");
+            sk_num_blocks = one_wave;
+        }
+        else if(sk_blocks < one_wave)
+        {
+            printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
+                   one_wave);
+            sk_num_blocks = sk_blocks;
+        }
+        else
+            sk_num_blocks = sk_blocks;
+
+        // default to regular DP GEMM if sk blocks == 0
         if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
         {
             sk_num_blocks = 0;
@@ -1064,7 +1083,7 @@ struct BlockToCTileMap_GemmStreamK
             k_iters_per_big_block = k_iters_per_sk_block + 1;

             dp_num_blocks      = dp_tiles;
-            dp_start_block_idx = sk_num_blocks;
+            dp_start_block_idx = ((sk_num_blocks + grid_size - 1) / grid_size) * grid_size;
         }
         n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
@@ -1079,15 +1098,15 @@ struct BlockToCTileMap_GemmStreamK
             equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
         }
-#if 0
-        printf("cu:%d, occupancy:%d, gridsize:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
+#if 1
+        printf("num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
                "sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
                "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
-               num_cu,
-               occupancy,
-               get_grid_dims(num_cu, occupancy).x,
+               // num_cu,
+               // occupancy,
+               // get_grid_dims(num_cu, occupancy).x,
                num_tiles,
                dp_tiles,
                sk_num_big_blocks,
...
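The constructor change above caps the requested stream-k block count at one full wave of the GPU (num_cu * occupancy) and falls back to a plain data-parallel GEMM when 0 (or 0xFFFFFFFF) blocks are requested. A minimal host-side sketch of that selection logic follows; choose_sk_num_blocks is a hypothetical free-function rendering of what the constructor does inline, and the numbers in main are made-up inputs.

    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper mirroring the sk_num_blocks selection in
    // BlockToCTileMap_GemmStreamK's constructor (not part of composable_kernel).
    static uint32_t choose_sk_num_blocks(uint32_t sk_blocks,
                                         uint32_t num_tiles,
                                         uint32_t num_cu,
                                         uint32_t occupancy)
    {
        const uint32_t one_wave = num_cu * occupancy; // one persistent workgroup per CU slot

        uint32_t sk_num_blocks;
        if((sk_blocks > one_wave) && (num_tiles > one_wave))
        {
            std::printf("WARNING: requested more stream-k blocks than the kernel can keep "
                        "resident, clamping to one full wave\n");
            sk_num_blocks = one_wave;
        }
        else if(sk_blocks < one_wave)
        {
            std::printf("Recommended #stream-k blocks (assuming full GPU availability): %u\n",
                        one_wave);
            sk_num_blocks = sk_blocks; // honour the request, but report the sweet spot
        }
        else
        {
            sk_num_blocks = sk_blocks;
        }

        // sk_blocks == 0 (or 0xFFFFFFFF) requests a regular data-parallel GEMM: no sk blocks.
        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
            sk_num_blocks = 0;

        return sk_num_blocks;
    }

    int main()
    {
        // Example: 304 CUs, occupancy 2, 4096 output tiles, user asks for 1000 sk blocks.
        std::printf("chosen sk blocks: %u\n", choose_sk_num_blocks(1000, 4096, 304, 2));
        return 0;
    }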
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp (file mode 100644 → 100755)
@@ -23,19 +23,19 @@ namespace ck {
 template <typename GridwiseGemm>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                    const typename GridwiseGemm::FloatAB* p_b_grid,
                                    typename GridwiseGemm::FloatC* p_c_grid,
                                    void* p_workspace,
                                    index_t M,
                                    index_t N,
                                    index_t K,
                                    index_t StrideA,
                                    index_t StrideB,
                                    index_t StrideC,
                                    typename GridwiseGemm::Block2CTileMap block_mapping)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         index_t StrideA;
         index_t StrideB;
         index_t StrideC;
+        index_t num_cu, occupancy; // stream-k arguments
         Block2CTileMap block_mapping;

         Argument(const FloatAB* p_a_grid_,
@@ -156,6 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_,
+                 uint32_t num_cu_,
+                 uint32_t occupancy_,
                  uint32_t num_sk_blocks_)
             : p_a_grid(p_a_grid_),
               p_b_grid(p_b_grid_),
@@ -166,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
               StrideA(StrideA_),
               StrideB(StrideB_),
               StrideC(StrideC_),
-              block_mapping(M, N, K, num_sk_blocks_)
+              num_cu(num_cu_),
+              occupancy(occupancy_),
+              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
         {
         }
@@ -518,11 +523,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

-        // offset for last acc buffer of this block
-        uint32_t block_acc_offset =
-            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
-            NPerBlock;
+        uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
+            reinterpret_cast<char*>(p_workspace) +
+            block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));

         uint32_t iter_start, iter_end;
         bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
@@ -555,11 +559,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 I1,
                 I1,
                 I1,
                 Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); // thread
                                                                         // buf STORE
                                                                         // descriptor

-#pragma unroll
     // stream-k: for new work for all the persistent blocks.
     for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
     {
+        // offset for last acc buffer of this block
+        uint32_t block_acc_offset =
+            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
+            NPerBlock;
         is_sk_block = block_idx < block_mapping.sk_num_blocks;
         is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
                       block_idx < block_mapping.reduction_start_block_idx;
@@ -621,6 +629,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             // start to compute
             auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
             auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);
+            workgroup_barrier wg_barrier(p_semaphore);
             uint32_t tile_acc_offset_start =
                 block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
@@ -666,6 +675,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                     thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL),
                 CElementwiseOperation{}};

+            // block synchronization
+            wg_barrier.wait_eq(0, block_mapping.sk_num_blocks);
 #if 0
             if(threadIdx.x == 0) {
@@ -1142,6 +1153,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 // make sure next loop LDS is ready for use
                 block_sync_lds();
             }
+            if constexpr(Block2CTileMap::ReductionStrategy ==
+                         StreamKReductionStrategy::Reduction)
+            {
+                if(is_sk_block)
+                {
+                    // increase the counter for this tile
+                    workgroup_barrier wg_barrier(p_semaphore);
+                    wg_barrier.inc(0);
+                    // printf("block_idx=%0d, \n",block_idx);
+                }
+            }
         }
     }
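The reduction-strategy code added above coordinates blocks through a semaphore placed after the accumulation buffers in the workspace: each stream-k block calls wg_barrier.inc(0) once its partial tiles are written out, and the reduction path calls wg_barrier.wait_eq(0, block_mapping.sk_num_blocks) before it starts accumulating them. The sketch below reproduces that producer/consumer handshake with plain HIP atomics; it is not composable_kernel's workgroup_barrier, and NUM_PRODUCERS, partials and the trivial payload are made-up stand-ins. It assumes all blocks are resident at once, as in the persistent stream-k launch.

    // semaphore_handshake.cpp — build with: hipcc semaphore_handshake.cpp -o semaphore_handshake
    #include <hip/hip_runtime.h>
    #include <cstdio>

    constexpr int NUM_PRODUCERS = 8; // stand-in for block_mapping.sk_num_blocks

    __global__ void streamk_like_handshake(float* partials, uint32_t* semaphore, float* out)
    {
        const int bid = blockIdx.x;

        if(bid < NUM_PRODUCERS)
        {
            // "Stream-k" block: publish a partial result, then bump the counter (like inc(0)).
            if(threadIdx.x == 0)
            {
                partials[bid] = static_cast<float>(bid + 1); // pretend partial tile result
                __threadfence();                             // make the partial visible first
                atomicAdd(semaphore, 1u);
            }
        }
        else if(threadIdx.x == 0)
        {
            // "Reduction" block: spin until every producer arrived (like wait_eq(0, N)).
            while(atomicAdd(semaphore, 0u) != static_cast<uint32_t>(NUM_PRODUCERS)) {}
            __threadfence();

            float acc = 0.f;
            for(int i = 0; i < NUM_PRODUCERS; ++i)
                acc += partials[i];
            *out = acc; // expect 1 + 2 + ... + 8 = 36
        }
    }

    int main()
    {
        float *partials, *out;
        uint32_t* semaphore;
        hipMalloc(&partials, NUM_PRODUCERS * sizeof(float));
        hipMalloc(&semaphore, sizeof(uint32_t));
        hipMalloc(&out, sizeof(float));
        hipMemset(semaphore, 0, sizeof(uint32_t)); // mirrors the hipMemsetAsync preprocess step

        hipLaunchKernelGGL(streamk_like_handshake, dim3(NUM_PRODUCERS + 1), dim3(64), 0, 0,
                           partials, semaphore, out);

        float result = 0.f;
        hipMemcpy(&result, out, sizeof(float), hipMemcpyDeviceToHost);
        std::printf("reduced value: %f\n", result);

        hipFree(partials);
        hipFree(semaphore);
        hipFree(out);
        return 0;
    }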