gaoqiong / composable_kernel / Commits / a3b4c5cb

Commit a3b4c5cb, authored Jun 03, 2022 by wangshaojie6
merge develop branch and add gridwise pipeline v3
Parents: 48918ab9, 1677cf70

Changes: 361 (showing 20 changed files with 2336 additions and 2993 deletions, +2336 / -2993).
Changed files shown in this view:

- include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp (+44, -56)
- include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp (+38, -59)
- include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp (+175, -89)
- include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp (+28, -28)
- include/ck/tensor_operation/gpu/device/device_reduce.hpp (+6, -23)
- include/ck/tensor_operation/gpu/device/device_reduce_blockwise.hpp (+0, -373)
- include/ck/tensor_operation/gpu/device/device_reduce_blockwise_second_call.hpp (+0, -327)
- include/ck/tensor_operation/gpu/device/device_reduce_common.hpp (+9, -9)
- include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp (+211, -138)
- include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp (+0, -439)
- include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp (+76, -78)
- include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp (+127, -0)
- include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+22, -36)
- include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp (+0, -14)
- include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+489, -0)
- include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp (+0, -886)
- include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp (+638, -0)
- include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp (+0, -269)
- include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp (+222, -169)
- include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp (+251, -0)
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp  (+44, -56)

@@ -12,6 +12,7 @@ adds #include "device_prop.hpp" after the existing "gridwise_gemm_xdlops_v2r4.hpp" and "gemm_specialization.hpp" includes.

@@ -332,17 +333,16 @@ (Argument constructor): block_2_ctile_map_ is now built up front with GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_) and passed to GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_) in place of the old M01_, N01_ arguments; previously the map was only constructed inside the validity branch, alongside c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.

@@ -385,21 +385,24 @@ (Invoker): the signature changes from float Run(const Argument& arg, int nrepeat = 1) to float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}); the runtime validity check passes arg.block_2_ctile_map_ instead of arg.M01_, arg.N01_ (error message "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting" unchanged); grid_size is now arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) rather than GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch).

@@ -408,50 +411,30 @@ the launch lambda no longer branches on nrepeat: it unconditionally zeroes the C buffer with hipMemset(arg.p_c_grid_, 0, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * sizeof(CDataType)) (annotated "FIXME: this should be moved outside of DeviceOp") and launches via launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, the grid descriptors, the element-wise ops and arg.block_2_ctile_map_); the old untimed launch_kernel path used when kbatch > 1 || nrepeat <= 0 is removed.

@@ -531,9 +514,10 @@ the polymorphic override becomes float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override and forwards stream_config.

@@ -545,11 +529,15 @@ IsSupportedArgument now returns false unless ck::get_device_name() is "gfx908" or "gfx90a", and its GridwiseGemm::CheckValidity call takes arg.block_2_ctile_map_ instead of arg.M01_, arg.N01_.

...
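For context, the nrepeat-to-StreamConfig change that runs through this and the following device ops boils down to the invoker pattern below. This is a minimal sketch, not the CK code: DummyInvoker and the two-field StreamConfig layout are assumptions made for illustration; only the shape of the old and new Run signatures is taken from the diff.

    // Minimal sketch of the invoker-signature migration seen in this commit.
    // StreamConfig's members here are an assumption; the diff only shows that
    // it is default-constructible and passed through to launch_and_time_kernel.
    #include <iostream>

    struct StreamConfig
    {
        void* stream_id_  = nullptr; // hypothetical: which HIP stream to launch on
        bool time_kernel_ = false;   // hypothetical: whether to time the launch
    };

    struct DummyInvoker
    {
        // Old style: the caller chose how many timed repetitions to run.
        float Run_old(int nrepeat = 1)
        {
            std::cout << "timing " << nrepeat << " repetition(s)\n";
            return 0.f;
        }

        // New style: the stream/timing policy travels in one struct, so adding
        // another knob later does not change every Run() signature again.
        float Run(const StreamConfig& stream_config = StreamConfig{})
        {
            std::cout << (stream_config.time_kernel_ ? "timed" : "untimed") << " launch\n";
            return 0.f;
        }
    };

    int main()
    {
        DummyInvoker invoker;
        invoker.Run_old(10);
        invoker.Run();                             // defaulted StreamConfig, as in the diff
        invoker.Run(StreamConfig{nullptr, true});  // explicit config
        return 0;
    }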
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp  (+38, -59)

@@ -292,8 +292,7 @@ the Block2CTileMap alias becomes typename GridwiseGemm::CBlockClusterAdaptor instead of decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)).

@@ -338,17 +337,16 @@ as in device_gemm_xdl_splitk.hpp, the Argument constructor now builds block_2_ctile_map_ with GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_) before the validity check and passes it to GridwiseGemm::CheckValidity in place of M01_, N01_; only c_grid_desc_mblock_mperblock_nblock_nperblock_ is still set inside the validity branch.

@@ -391,21 +389,24 @@ Invoker::Run switches from int nrepeat = 1 to const StreamConfig& stream_config = StreamConfig{}, validates with arg.block_2_ctile_map_ (error message "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting" unchanged), and computes grid_size via arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_).

@@ -414,51 +415,29 @@ the launch lambda mirrors the split-K header above: an unconditional hipMemset of arg.p_c_grid_ sized by arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * sizeof(CDataType), followed by launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, ...); the nrepeat-based and untimed launch_kernel branches are removed.

@@ -542,9 +521,10 @@ the polymorphic Run override now takes const StreamConfig& stream_config = StreamConfig{} and forwards it.

@@ -559,8 +539,7 @@ IsSupportedArgument's CheckValidity call passes arg.block_2_ctile_map_ instead of arg.M01_, arg.N01_.

...
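The shared pattern in both split-K GEMM headers, namely build the block-to-C-tile map first, validate with it, then derive the launch grid size from it, can be sketched as below. TileMap2D and its internals are assumptions for illustration only; in the real headers the map comes from GridwiseGemm::MakeCBlockClusterAdaptor and the grid size from block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_).

    // Toy stand-in for the block-to-C-tile map used by the split-K device ops.
    // All names here are hypothetical; only the call order (make map -> check
    // validity -> CalculateGridSize) mirrors the new code in this commit.
    #include <cassert>
    #include <cstdint>
    #include <iostream>

    struct TileMap2D
    {
        std::int64_t m_blocks, n_blocks, k_batch;

        std::int64_t CalculateGridSize() const { return m_blocks * n_blocks * k_batch; }

        // A real map also checks that every block id maps to a valid C tile.
        bool CheckValidity(std::int64_t M, std::int64_t N,
                           std::int64_t MPerBlock, std::int64_t NPerBlock) const
        {
            return m_blocks * MPerBlock >= M && n_blocks * NPerBlock >= N;
        }
    };

    TileMap2D MakeTileMap(std::int64_t M, std::int64_t N, std::int64_t MPerBlock,
                          std::int64_t NPerBlock, std::int64_t k_batch)
    {
        return TileMap2D{(M + MPerBlock - 1) / MPerBlock,
                         (N + NPerBlock - 1) / NPerBlock,
                         k_batch};
    }

    int main()
    {
        const std::int64_t M = 1000, N = 768, MPerBlock = 256, NPerBlock = 128, k_batch = 4;

        // 1) build the map up front (the Argument constructor now does this),
        const TileMap2D map = MakeTileMap(M, N, MPerBlock, NPerBlock, k_batch);
        // 2) validate against the map instead of against raw M01/N01 factors,
        assert(map.CheckValidity(M, N, MPerBlock, NPerBlock));
        // 3) let the map, not the gridwise GEMM, decide the launch grid size.
        std::cout << "grid size = " << map.CalculateGridSize() << '\n'; // 4 * 6 * 4 = 96
        return 0;
    }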
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp  (+175, -89)

@@ -17,6 +17,62 @@ adds a __global__ kernel kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm, FloatAB, FloatC, GemmDesc, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, HasMainKBlockLoop> in namespace ck::tensor_operation::device. It receives the per-group descriptors as const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const plus an index_t group_count, allocates __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()], maps the 1-D block id to its group by scanning each descriptor's BlockStart_/BlockEnd_ range, and dispatches GridwiseGemm::template Run<HasMainKBlockLoop> with that group's a/b/c pointers, grid descriptors, element-wise ops and grouped_gemm_block_2_ctile_map_; on targets other than gfx908/gfx90a all arguments are ignored.

@@ -225,6 +281,11 @@ GroupedGemmBlock2CTileMap gains a UnderlyingBlock2CTileMap alias and a static_assert that decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)) is the same type as GridwiseGemm::DefaultBlock2CTileMap.

@@ -247,7 +308,18 @@ it also gains ValidCTileIndex and CheckValidity helpers that forward to the wrapped block_2_ctile_map_, with the block_2_ctile_map_ and BlockStart_ members moved behind a private: section.

@@ -290,17 +362,20 @@ (Argument constructor): group_count_ is computed with ck::type_convert<ck::index_t> instead of static_cast<int>, a new void* gemm_descs_args_workspace_ member is initialized to nullptr, the p_a/p_b/p_c size checks also go through type_convert, and the per-group loop index becomes std::size_t.

@@ -317,22 +392,26 @@ per group, grid_size_grp is now GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0).block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_) instead of GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_); grouped_gemm_block_2_ctile_map_ is constructed with BlockStart before the validity check, and GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, ...) takes it instead of M01_, N01_.

@@ -358,6 +437,8 @@ Argument stores the new void* gemm_descs_args_workspace_ next to gemm_desc_kernel_arg_ and grid_size_.

@@ -366,83 +447,77 @@ Invoker::Run now takes const StreamConfig& stream_config = StreamConfig{}. The old path that copied each GemmDescKernelArg into a StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> inside static_for, validated with arg.M01_/arg.N01_ and checked GridwiseGemm::CalculateHasMainK0BlockLoop(K0) against has_main_k0_block_loop is replaced by a plain loop over arg.gemm_desc_kernel_arg_ that prints each group's a/b/c descriptors, validates with that group's grouped_gemm_block_2_ctile_map_, and checks GridwiseGemm::CalculateHasMainKBlockLoop(K) with K = GetLength(I0) * GetLength(I2) against has_main_k_block_loop ("wrong! not all gemm has_main_k_block_loop" otherwise). The kernel-argument array is then uploaded once via hipMemcpy(arg.gemm_descs_args_workspace_, arg.gemm_desc_kernel_arg_.data(), size * sizeof(GemmDescKernelArg), hipMemcpyHostToDevice). The kernel is instantiated with GemmDescKernelArg (no remove_reference_t) and without the old trailing MaxGroupCount parameter, and is launched as launch_and_time_kernel(stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0, cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), arg.gemm_desc_kernel_arg_.size(), a/b/c element ops).

@@ -450,32 +525,33 @@ the same changes apply in the branch without a main K block loop, and the polymorphic Run override switches to const StreamConfig& stream_config = StreamConfig{}.

@@ -487,7 +563,7 @@ IsSupportedArgument compares ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) against arg.group_count_.

@@ -554,6 +630,16 @@ adds size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override, returning group_count_ * sizeof(GemmDescKernelArg), and void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override, which stores the pointer into gemm_descs_args_workspace_.

...
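The group lookup the new kernel performs is simple enough to show on the host. The sketch below is a simplified, plain host C++ rendering: GemmDesc is a stripped-down stand-in for the real GemmDescKernelArg, but the scan itself mirrors the loop in the added kernel.

    // Host-side sketch of how the new grouped-GEMM kernel picks its group: each
    // group owns the half-open block range [BlockStart_, BlockEnd_), and a
    // linear scan maps a block id to its group.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct GemmDesc // stand-in for GemmDescKernelArg
    {
        std::int32_t BlockStart_;
        std::int32_t BlockEnd_;
    };

    std::int32_t FindGroupId(const GemmDesc* gemm_desc_ptr, std::int32_t group_count,
                             std::int32_t block_id)
    {
        std::int32_t group_id = 0;
        for(std::int32_t i = 0; i < group_count; ++i)
        {
            // Same branch-free update as the device code: keep the previous id
            // unless block_id falls inside group i's block range.
            group_id = (block_id >= gemm_desc_ptr[i].BlockStart_ &&
                        block_id < gemm_desc_ptr[i].BlockEnd_)
                           ? i
                           : group_id;
        }
        return group_id;
    }

    int main()
    {
        // Three groups occupying blocks [0,8), [8,20) and [20,26).
        const std::vector<GemmDesc> descs{{0, 8}, {8, 20}, {20, 26}};
        for(std::int32_t block_id : {0, 7, 8, 19, 25})
            std::cout << "block " << block_id << " -> group "
                      << FindGroupId(descs.data(), 3, block_id) << '\n';
        return 0;
    }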
include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp  (+28, -28)

@@ -17,7 +17,7 @@ the template parameter bool NeedIndices is renamed to bool OuputIndex.

@@ -44,8 +44,6 @@ removes static constexpr bool BetaIsZero = true;; InSrcOutDstVectorDim stays 0 (for NHWC, dim C is the vector dim for both input and output in memory and is not reduced).

@@ -204,30 +202,30 @@ Invoker::Run switches from int nrepeat = 1 to const StreamConfig& stream_config = StreamConfig{}. The GridwiseReduction_mk_to_m_threadwise instantiation now passes InMemoryDataOperationEnum::Set and drops BetaIsZero and the ReduceMThreadClusterSize/ReduceKThreadClusterSize parameters; kernel_reduce_threadwise is instantiated with OuputIndex plus an extra false ("don't have index input") in place of NeedIndices.

@@ -241,8 +239,8 @@ the grid size stays ReduceM / ReduceM_BlockTileSize, and launch_and_time_kernel now takes stream_config first instead of nrepeat.

@@ -252,14 +250,16 @@ the kernel-argument list gains a nullptr for the absent input-index buffer between arg.p_in_dev_ and float(0), and the polymorphic Run override switches to const StreamConfig& stream_config = StreamConfig{}.

...
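The pooling invoker's grid size is just ReduceM / ReduceM_BlockTileSize, i.e. one block per tile of flattened output rows. The arithmetic sketch below uses invented numbers and an assumed block-tile definition (BlockSize * ReduceMThreadSliceSize); only the grid-size formula itself comes from the diff.

    // Illustrative only: the shapes and tuning parameters below are made up;
    // the formula grid_size = ReduceM / ReduceM_BlockTileSize is from the diff.
    #include <iostream>

    int main()
    {
        const int N = 2, Ho = 112, Wo = 112, C = 64;           // hypothetical NHWC output shape
        const int BlockSize = 256, ReduceMThreadSliceSize = 1; // hypothetical tuning parameters
        const int ReduceM = N * Ho * Wo * C;                   // rows of the flattened (M, K) problem
        const int ReduceM_BlockTileSize = BlockSize * ReduceMThreadSliceSize; // assumed tile definition
        std::cout << "grid_size = " << ReduceM / ReduceM_BlockTileSize << '\n'; // 1605632 / 256 = 6272
        return 0;
    }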
include/ck/tensor_operation/gpu/device/device_reduce.hpp  (+6, -23)

@@ -16,35 +16,18 @@ the DeviceReduce<InElementwiseOperation, AccElementwiseOperation> base loses its workspace-related virtuals (GetWorkspaceSizeInBytes, HasFurtherCall, GetWorkspace2dLengths). MakeArgumentPointer now takes the input/output lengths and strides as std::vector<index_t> instead of std::vector<int>, adds a const void* in_index_dev parameter, renames out_indices_dev to out_index_dev, and drops the void* workspace_dev parameter; reduceDims, alpha, beta and the element-wise operation parameters are unchanged.

...
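To make the narrowed interface concrete, here is a hedged sketch of an abstract base mirroring the new MakeArgumentPointer shape: index_t lengths/strides, explicit index buffers, and no workspace pointer. The surrounding types (BaseArgument stub, the index_t alias, the class name DeviceReduceLike) are stand-ins for illustration; the parameter list follows the signature shown in this hunk.

    // Sketch of the narrowed DeviceReduce-style interface after this commit.
    #include <cstdint>
    #include <memory>
    #include <vector>

    using index_t = std::int32_t; // stand-in for ck::index_t

    struct BaseArgument { virtual ~BaseArgument() = default; };

    template <typename InElementwiseOperation, typename AccElementwiseOperation>
    struct DeviceReduceLike // hypothetical name; mirrors DeviceReduce's new interface
    {
        virtual ~DeviceReduceLike() = default;

        virtual std::unique_ptr<BaseArgument>
        MakeArgumentPointer(const std::vector<index_t> inLengths,   // was std::vector<int>
                            const std::vector<index_t> inStrides,
                            const std::vector<index_t> outLengths,
                            const std::vector<index_t> outStrides,
                            const std::vector<int> reduceDims,
                            float alpha,
                            float beta,
                            const void* in_dev,
                            const void* in_index_dev, // new: optional input indices
                            void* out_dev,
                            void* out_index_dev,      // renamed from out_indices_dev
                            const InElementwiseOperation in_elementwise_op,
                            const AccElementwiseOperation acc_elementwise_op) = 0;
        // note: no workspace_dev parameter and no workspace-size virtuals anymore
    };

    int main() { return 0; } // interface sketch only; nothing to run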
include/ck/tensor_operation/gpu/device/device_reduce_blockwise.hpp  (+0, -373)
deleted 100644 → 0

The removed header (guarded by DEVICE_REDUCE_BLOCKWISE_HPP) defined struct DeviceReduceBlockWise<InDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, AccElementwiseOperation, PropagateNan, NeedIndices, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>. It contained:

- static_asserts requiring Rank <= 6, BlockSize == MThreadClusterSize * KThreadClusterSize, and consistent vector-size/slice-size divisibility;
- MakeSrc2dDescriptor / MakeDst1dDescriptor helpers that merge the invariant and reduce dimensions into a 2-D (M, K) source and 1-D (M) destination descriptor and right-pad them to multiples of M_BlockTileSize and K_BlockTileSize via make_right_pad_transform;
- an Argument that shuffles inLengths/inStrides with shuffle_tensor_dimensions<Rank, NumReduceDim>, converts alpha/beta to AccDataType, and derives the invariant/reduce lengths and the grid size;
- an Invoker whose Run(const Argument&, int nrepeat = 1) instantiated GridwiseReduction_mk_to_m_blockwise, launched kernel_reduce_blockwise through launch_and_time_kernel(kernel, nrepeat, ...), and returned the averaged time;
- IsSupportedArgument checks on unit strides, vector-size divisibility, and a minimum reduce length (reduce_total_length / KThreadSliceSize >= 2; smaller cases are left to the threadwise method);
- MakeArgumentPointer, MakeInvokerPointer and GetTypeString ("DeviceReduceBlockWise<...>").
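The padding arithmetic the deleted MakeSrc2dDescriptor performed is worth keeping in mind when reading the reworked reduction headers: the flattened (M, K) problem is right-padded so both extents become whole multiples of the block tile. The sketch below uses a plain helper rather than ck::math::integer_least_multiple and made-up tile sizes; the formula matches the deleted code.

    // Sketch of the M/K right-padding used by the removed blockwise reduction.
    #include <iostream>

    long long integer_least_multiple(long long x, long long multiple)
    {
        // smallest multiple of `multiple` that is >= x
        return ((x + multiple - 1) / multiple) * multiple;
    }

    int main()
    {
        const long long invariantLength = 1000, reduceLength = 300; // hypothetical M, K
        const long long M_BlockTileSize = 128, K_BlockTileSize = 64; // hypothetical tiles

        const long long inPad_M =
            integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const long long inPad_K =
            integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

        // The real code wrapped these pads into make_right_pad_transform() calls
        // on a tensor descriptor; here we just print the padded extents.
        std::cout << "padded M = " << invariantLength + inPad_M   // 1024
                  << ", padded K = " << reduceLength + inPad_K    // 320
                  << '\n';
        return 0;
    }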
include/ck/tensor_operation/gpu/device/device_reduce_blockwise_second_call.hpp  (+0, -327)
deleted 100644 → 0

The removed header (guarded by DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP) defined struct DeviceReduceBlockWiseSecondCall with the same template parameter list as DeviceReduceBlockWise, used for the second pass of a two-stage reduction. It required InDataType and AccDataType to be the same type and InSrcVectorDim == 1; its MakeSrc2dDescriptor treated the input as an already flattened 2-D (M, K) workspace; its Argument computed a 64-byte aligned ws_buf2_bytes_offset = math::integer_least_multiple(invariant_total_length * reduce_total_length * sizeof(AccDataType), 64) to locate the index portion of the workspace when NeedIndices; its Invoker launched kernel_reduce_blockwise_second_call via launch_and_time_kernel(kernel, nrepeat, ...); IsSupportedArgument, MakeArgumentPointer (which ignored reduceDims), MakeInvokerPointer and GetTypeString rounded out the class.
include/ck/tensor_operation/gpu/device/device_reduce_common.hpp  View file @ a3b4c5cb

@@ -14,13 +14,13 @@ namespace device {
 // here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
 // of reduce dims
-template <int Rank, int NumReduceDim>
-std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
+template <index_t Rank, int NumReduceDim>
+std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
 {
     static_assert(Rank <= 6, "bigger Rank size not supported!");

-    size_t invariant_total_length = 1;
-    size_t reduce_total_length    = 1;
+    long_index_t invariant_total_length = 1;
+    long_index_t reduce_total_length    = 1;

     constexpr int NumInvariantDim = Rank - NumReduceDim;

@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
 // helper functions using variadic template arguments
 template <index_t... Ns>
-auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
+auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
 {
     return make_tuple(static_cast<index_t>(lengths[Ns])...);
 };

 template <index_t arraySize>
-static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>)
+auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
 {
     static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");

@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
 };

 template <index_t Rank, index_t NumReduceDim>
-std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
-                                           const std::vector<int>& reduceDims)
+std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
+                                               const std::vector<int>& reduceDims)
 {
-    std::vector<int> newLengthsStrides;
+    std::vector<index_t> newLengthsStrides;

     assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
 ...
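For readers following the int → index_t / long_index_t changes above: these helpers reorder the tensor dimensions so the reduced ones come last, then collapse the shape into a 2-D (invariant, reduce) problem. The snippet below is a plain host-side sketch of that bookkeeping, not code from this commit; the function name to_2d_lengths and its concrete types are illustrative only.

#include <cstddef>
#include <utility>
#include <vector>

// Shuffle a length array so the reduced dimensions come last, then collapse it
// into {invariant_total_length, reduce_total_length}. Illustrative sketch only.
std::pair<std::size_t, std::size_t> to_2d_lengths(const std::vector<int>& lengths,
                                                  const std::vector<int>& reduceDims)
{
    std::vector<int> shuffled;
    std::vector<bool> isReduce(lengths.size(), false);
    for(int d : reduceDims)
        isReduce[d] = true;

    for(std::size_t d = 0; d < lengths.size(); ++d) // invariant dims first
        if(!isReduce[d])
            shuffled.push_back(lengths[d]);
    for(int d : reduceDims) // reduce dims last
        shuffled.push_back(lengths[d]);

    std::size_t invariant_total = 1, reduce_total = 1;
    const std::size_t numInvariant = lengths.size() - reduceDims.size();
    for(std::size_t d = 0; d < shuffled.size(); ++d)
        (d < numInvariant ? invariant_total : reduce_total) *= shuffled[d];

    return {invariant_total, reduce_total};
}
// e.g. lengths {16, 64, 32, 32} with reduceDims {1} -> {16*32*32, 64}.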
include/ck/tensor_operation/gpu/device/device_reduce_multiblock_atomic_add.hpp → include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp  View file @ a3b4c5cb

-#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
-#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
+#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP
+#define DEVICE_REDUCE_MULTIBLOCK_HPP

 #include <iostream>
 #include <sstream>

@@ -7,8 +7,9 @@
 #include "device_base.hpp"
 #include "device_reduce.hpp"
 #include "device_reduce_common.hpp"
-#include "gridwise_2d_reduction_multiblock_atomic_add.hpp"
+#include "gridwise_2d_reduction_multiblock.hpp"
 #include "gridwise_set_buffer_value.hpp"
+#include "reduction_operator.hpp"

 namespace ck {
 namespace tensor_operation {

@@ -22,8 +23,10 @@ template <typename InDataType,
           typename ReduceOperation,
           typename InElementwiseOperation,
           typename AccElementwiseOperation,
+          InMemoryDataOperationEnum OutMemoryDataOperation,
           bool PropagateNan,
-          bool NeedIndices,
+          bool OutputIndex,
+          bool HaveIndexInputIfOutputIndex,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,

@@ -32,8 +35,7 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlockAtomicAdd
+struct DeviceReduceMultiBlock
     : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,

@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd
     using IndexDataType = int32_t;

+    static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
+
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

     static constexpr index_t numSrcDim = Rank;
     static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
     static constexpr bool reduceAllDim = (NumInvariantDim == 0);

-    static constexpr bool support_AtomicAdd =
-        std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
+    // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
+    // later
+    static constexpr bool use_multiblock =
+        (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
+
+    static constexpr bool out_type_compatible_with_atomic_op =
+        std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

-    static_assert(!NeedIndices && support_AtomicAdd,
-                  "MultiBlockAtomicAdd method can only be used with non-indiced operation and when "
-                  "having float/double output type!");
+    static_assert(
+        !use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
+        "The OutDataType must support the atomic operation for using MultiBlock reduction");
+
+    static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
+                  "MultiBlock reduction can only be used when outputing index is not required");
+
+    static_assert(ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
+                  "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");

-    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
-    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

-    static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
-                                    const std::vector<int>& inStrides,
+    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
+                                    const std::vector<index_t>& inStrides,
                                     int blkGroupSize,
-                                    int kBlockTileIterations)
+                                    int numBlockTileIteration)
     {
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd
         const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
         const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

-        const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
         const auto inPad_M =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
         const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd
         return (in_grid_desc_m_k_padded);
     };

-    static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
-                                    const std::vector<int>& outStrides)
+    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
+                                    const std::vector<index_t>& outStrides)
     {
         const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
         const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd
         return (out_grid_desc_m_padded);
     };

+    static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
+                                                const std::vector<index_t>& outStrides)
+    {
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+
+        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
+
+        auto out_grid_desc_m = transform_tensor_descriptor(
+            outDesc,
+            make_tuple(make_merge_transform(tupleDstLengths)),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto length = out_grid_desc_m.GetLength(Number<0>{});
+
+        const auto pad = math::integer_least_multiple(length, BlockSize) - length;
+
+        auto out_grid_desc_m_padded =
+            transform_tensor_descriptor(out_grid_desc_m,
+                                        make_tuple(make_right_pad_transform(length, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return (out_grid_desc_m_padded);
+    };
+
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int> inLengths,
-                 const std::vector<int> inStrides,
-                 const std::vector<int> outLengths,
-                 const std::vector<int> outStrides,
+        Argument(const std::vector<index_t> inLengths,
+                 const std::vector<index_t> inStrides,
+                 const std::vector<index_t> outLengths,
+                 const std::vector<index_t> outStrides,
                  const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
+                 const IndexDataType* in_index_dev,
                  OutDataType* out_dev,
-                 IndexDataType* out_indices_dev,
-                 AccDataType* workspace_dev,
+                 IndexDataType* out_index_dev,
                  const InElementwiseOperation in_elementwise_op,
                  const AccElementwiseOperation acc_elementwise_op)
             : outLengths_{outLengths},
               outStrides_{outStrides},
               in_dev_{in_dev},
+              in_index_dev_{in_index_dev},
               out_dev_{out_dev},
+              out_index_dev_{out_index_dev},
               in_elementwise_op_{in_elementwise_op},
               acc_elementwise_op_{acc_elementwise_op}
         {
-            (void)out_indices_dev;
-            (void)workspace_dev;
-
             inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
             inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd
             reduce_lowest_length = inLengths_[Rank - 1];

-            int iterations = 1;
-            while(true)
+            if constexpr(use_multiblock)
             {
-                int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
-                                       (K_BlockTileSize * iterations);
+                int iterations = 1;
+                while(true)
+                {
+                    int testBlkGroupSize =
+                        (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
+                        (K_BlockTileSize * iterations);

-                // we want the blkGroupSize be not more than 128
-                if(testBlkGroupSize <= 128)
-                    break;
+                    // we want the blkGroupSize be not more than 128
+                    if(testBlkGroupSize <= 128)
+                        break;

-                iterations++;
-            };
+                    iterations++;
+                };

-            blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
-                           (K_BlockTileSize * iterations);
-
-            kBlockTileIterations = iterations;
+                blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
+                               (K_BlockTileSize * iterations);
+                numBlockTileIteration = iterations;
+            }
+            else
+            {
+                blkGroupSize          = 1;
+                numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+            };

             gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize * blkGroupSize;

@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd
                 math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
         }

-        std::vector<int> inLengths_;
-        std::vector<int> inStrides_;
-        std::vector<int> outLengths_;
-        std::vector<int> outStrides_;
+        std::vector<index_t> inLengths_;
+        std::vector<index_t> inStrides_;
+        std::vector<index_t> outLengths_;
+        std::vector<index_t> outStrides_;

         AccDataType alpha_;
         AccDataType beta_;

         const InDataType* in_dev_;
+        const IndexDataType* in_index_dev_;
         OutDataType* out_dev_;
+        IndexDataType* out_index_dev_;

         InElementwiseOperation in_elementwise_op_;
         AccElementwiseOperation acc_elementwise_op_;

-        int invariant_lowest_length;
-        int reduce_lowest_length;
-        size_t invariant_total_length;
-        size_t reduce_total_length;
+        index_t invariant_lowest_length;
+        index_t reduce_lowest_length;
+        long_index_t invariant_total_length;
+        long_index_t reduce_total_length;

-        index_t blkGroupSize;
-        index_t kBlockTileIterations;
+        int blkGroupSize;
+        int numBlockTileIteration;
         size_t gridSize;

         size_t gridSize_pre;

@@ -245,91 +299,97 @@ struct DeviceReduceMultiBlockAtomicAdd
     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
-                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
-            const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor(
-                arg.outLengths_, arg.outStrides_);
+            const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
+                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
+            const auto out_grid_desc_m =
+                DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
+            const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
+                arg.outLengths_, arg.outStrides_);

-            using InGridDesc_M_K = decltype(in_grid_desc_m_k);
-            using OutGridDesc_M  = decltype(out_grid_desc_m);
+            using InGridDesc_M_K  = decltype(in_grid_desc_m_k);
+            using OutGridDesc_M   = decltype(out_grid_desc_m);
+            using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);

-            using GridwiseReduce =
-                GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType, OutDataType, AccDataType,
-                                                                InGridDesc_M_K, OutGridDesc_M,
-                                                                ReduceOperation, InElementwiseOperation,
-                                                                AccElementwiseOperation, PropagateNan,
-                                                                BlockSize, MThreadClusterSize,
-                                                                KThreadClusterSize, MThreadSliceSize,
-                                                                KThreadSliceSize, InSrcVectorDim,
-                                                                InSrcVectorSize, OutDstVectorSize>;
+            using GridwiseReduce =
+                GridwiseReduction_mk_to_m_multiblock<InDataType, OutDataType, AccDataType,
+                                                     IndexDataType, InGridDesc_M_K, OutGridDesc_M,
+                                                     ReduceOperation, InElementwiseOperation,
+                                                     AccElementwiseOperation, OutMemoryDataOperation,
+                                                     PropagateNan, BlockSize, MThreadClusterSize,
+                                                     KThreadClusterSize, MThreadSliceSize,
+                                                     KThreadSliceSize, InSrcVectorDim,
+                                                     InSrcVectorSize, OutDstVectorSize>;

-            float avg_time = 0;
-
-            KernelTimer timer;
-
-            const auto kernel_pre  = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
-            const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
-                                                                         InDataType, OutDataType,
-                                                                         AccDataType, InGridDesc_M_K,
-                                                                         OutGridDesc_M,
-                                                                         InElementwiseOperation,
-                                                                         AccElementwiseOperation>;
-
-            printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n",
-                   arg.gridSize, BlockSize);
-
-            printf("Warm up \n");
-
-            for(int i = 0; i < nrepeat + 1; i++)
-            {
-                if(i == 1)
-                    timer.Start();
-
-                launch_kernel(kernel_pre, dim3(arg.gridSize_pre), dim3(BlockSize), 0,
-                              out_grid_desc_m, arg.out_dev_, static_cast<OutDataType>(0.0f));
-
-                launch_kernel(kernel_main, dim3(arg.gridSize), dim3(BlockSize), 0,
-                              in_grid_desc_m_k, out_grid_desc_m,
-                              arg.in_elementwise_op_, arg.acc_elementwise_op_,
-                              arg.blkGroupSize, arg.kBlockTileIterations,
-                              arg.alpha_, arg.in_dev_, arg.out_dev_);
-            };
-
-            timer.End();
-
-            avg_time = timer.GetElapsedTime() / nrepeat;
+            const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce, OutputIndex,
+                                                              HaveIndexInput, InDataType,
+                                                              OutDataType, AccDataType, int32_t,
+                                                              InGridDesc_M_K, OutGridDesc_M,
+                                                              InElementwiseOperation,
+                                                              AccElementwiseOperation>;
+
+            float avg_time = 0;
+
+            if constexpr(use_multiblock)
+            {
+                const auto identityVal =
+                    ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
+                        OutMemoryDataOperation);
+
+                const auto kernel_pre =
+                    kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;
+
+                avg_time += launch_and_time_kernel(stream_config, kernel_pre,
+                                                   dim3(arg.gridSize_pre), dim3(BlockSize), 0,
+                                                   out_grid_desc_m_2, arg.out_dev_, identityVal);
+            };
+
+            avg_time += launch_and_time_kernel(stream_config, kernel_main,
+                                               dim3(arg.gridSize), dim3(BlockSize), 0,
+                                               in_grid_desc_m_k, out_grid_desc_m,
+                                               arg.in_elementwise_op_, arg.acc_elementwise_op_,
+                                               arg.blkGroupSize, arg.numBlockTileIteration,
+                                               arg.alpha_, arg.in_dev_, arg.in_index_dev_,
+                                               arg.beta_, arg.out_dev_, arg.out_index_dev_);

             return (avg_time);
         };

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         };
     };

@@ -337,6 +397,12 @@ struct DeviceReduceMultiBlockAtomicAdd
     {
         const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

+        if constexpr(use_multiblock)
+        {
+            if(static_cast<float>(pArg->beta_) != 0.0f)
+                return (false);
+        };
+
         if constexpr(InSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)

@@ -361,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd
                 return (false);
         };

-        if(static_cast<float>(pArg->beta_) != 0.0f)
-            return (false);
-
         // To improve
         if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
             return (false);

-        // cases with small reduce_total_length should be handled by the BlockWise method
-        if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
-            return (false);
-
-        // This is very strong restriction, but needed to avoid some failure
-        if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
-            return (false);
+        if constexpr(use_multiblock)
+        {
+            // blkGroupSize of 1 should be handled by Blockwise path using
+            // InMemoryDataOperationEnum::Set
+            if(pArg->blkGroupSize == 1)
+                return (false);
+
+            // This is very strong restriction, but needed to avoid some failure
+            if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
+                return (false);
+        }
+        else
+        {
+            // cases with very small reduce_total_length should be handled by ThreadWise kernel
+            if(pArg->reduce_total_length / KThreadSliceSize < 2)
+                return (false);
+        };

         return (true);
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int> inLengths,
-                        const std::vector<int> inStrides,
-                        const std::vector<int> outLengths,
-                        const std::vector<int> outStrides,
+    MakeArgumentPointer(const std::vector<index_t> inLengths,
+                        const std::vector<index_t> inStrides,
+                        const std::vector<index_t> outLengths,
+                        const std::vector<index_t> outStrides,
                         const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
+                        const void* in_index_dev,
                         void* out_dev,
-                        void* out_indices_dev,
-                        void* workspace_dev,
+                        void* out_index_dev,
                         const InElementwiseOperation in_elementwise_op,
                         const AccElementwiseOperation acc_elementwise_op) override
     {

@@ -402,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd
                                           alpha,
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
+                                          static_cast<const IndexDataType*>(in_index_dev),
                                           static_cast<OutDataType*>(out_dev),
-                                          static_cast<IndexDataType*>(out_indices_dev),
-                                          static_cast<AccDataType*>(workspace_dev),
+                                          static_cast<IndexDataType*>(out_index_dev),
                                           in_elementwise_op,
                                           acc_elementwise_op);
     };
 ...
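The blkGroupSize / numBlockTileIteration selection in the Argument constructor above can be restated in isolation as follows. This is an illustrative sketch under the same 128-block cap shown in the diff, not library code; split_reduce_length and MultiblockSplit are made-up names.

#include <cstdint>

struct MultiblockSplit
{
    int blkGroupSize;          // blocks cooperating on one output row
    int numBlockTileIteration; // K tiles each block walks through
};

MultiblockSplit split_reduce_length(std::int64_t reduce_total_length, int k_block_tile_size)
{
    int iterations = 1;
    while(true)
    {
        int testBlkGroupSize =
            (reduce_total_length + (std::int64_t(k_block_tile_size) * iterations) - 1) /
            (std::int64_t(k_block_tile_size) * iterations);
        if(testBlkGroupSize <= 128) // keep the block group at or below 128 blocks
            break;
        iterations++;
    }
    int blkGroupSize = (reduce_total_length + (std::int64_t(k_block_tile_size) * iterations) - 1) /
                       (std::int64_t(k_block_tile_size) * iterations);
    return {blkGroupSize, iterations};
}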
include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp  deleted 100644 → 0  View file @ 48918ab9

#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP

#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          index_t Rank,
          index_t NumReduceDim,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          bool NeedIndices,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct DeviceReduceMultiBlockPartialReduce
    : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                      (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");

    using IndexDataType = int32_t;

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t numSrcDim = Rank;
    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    static constexpr int MaxBlockGroupSize = 256;

    long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
                                         const std::vector<int> reduceDims) override
    {
        size_t invariant_total_length;
        size_t reduce_total_length;

        auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);

        std::tie(invariant_total_length, reduce_total_length) =
            get_2d_lengths<Rank, NumReduceDim>(inLengths_);

        int iterations = 1;
        while(true)
        {
            int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                   (K_BlockTileSize * iterations);

            if(testBlkGroupSize <= MaxBlockGroupSize)
                break;

            iterations++;
        };

        int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                           (K_BlockTileSize * iterations);

        long_index_t workspace_size = invariant_total_length * blkGroupSize;

        long_index_t wsSizeInBytes =
            !NeedIndices ? workspace_size * sizeof(AccDataType)
                         : workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);

        return (wsSizeInBytes);
    };

    bool HasFurtherCall() override { return (true); };

    static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
                                    const std::vector<int>& inStrides,
                                    int blkGroupSize,
                                    int kBlockTileIterations)
    {
        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
            if constexpr(reduceAllDim)
            {
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
                                                   make_tuple(make_unmerge_transform(make_tuple(
                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
                                                   make_tuple(Sequence<0>{}),
                                                   make_tuple(Sequence<0, 1>{}));
            }
            else
            {
                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

                const auto reduceDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                const auto invariantDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

                return transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(invariantDimLengths),
                               make_merge_transform(reduceDimLengths)),
                    make_tuple(InvariantDims{}, ReduceDims{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();

        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

        const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
        const auto inPad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
            in_grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                       make_right_pad_transform(reduceLength, inPad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

    static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
    {
        auto ws_desc_m_k =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

        const auto wsPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto ws_desc_m_k_padded =
            transform_tensor_descriptor(ws_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, wsPad),
                                                   make_pass_through_transform(blkGroupSize)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (ws_desc_m_k_padded);
    };

    struct Argument : public BaseArgument
    {
        Argument(const std::vector<int> inLengths,
                 const std::vector<int> inStrides,
                 const std::vector<int> outLengths,
                 const std::vector<int> outStrides,
                 const std::vector<int> reduceDims,
                 float alpha,
                 float beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_indices_dev,
                 AccDataType* workspace_dev,
                 const InElementwiseOperation in_elementwise_op,
                 const AccElementwiseOperation acc_elementwise_op)
            : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},
              out_dev_{out_dev},
              out_indices_dev_{out_indices_dev},
              workspace_dev_{workspace_dev},
              in_elementwise_op_{in_elementwise_op},
              acc_elementwise_op_{acc_elementwise_op}
        {
            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

            alpha_ = type_convert<AccDataType>(alpha);
            beta_  = type_convert<AccDataType>(beta);

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

            if constexpr(NumInvariantDim == 0)
                invariant_lowest_length = 1;
            else
                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

            reduce_lowest_length = inLengths_[Rank - 1];

            int iterations = 1;
            while(true)
            {
                int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                       (K_BlockTileSize * iterations);

                if(testBlkGroupSize <= MaxBlockGroupSize)
                    break;

                iterations++;
            };

            blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                           (K_BlockTileSize * iterations);

            kBlockTileIterations = iterations;

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize * blkGroupSize;

            size_t ws_buf2_bytes_offset = math::integer_least_multiple(
                invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);

            if constexpr(NeedIndices)
                workspace_indices_dev_ = reinterpret_cast<int*>(
                    reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
            else
                workspace_indices_dev_ = nullptr;
        }

        std::vector<int> inLengths_;
        std::vector<int> inStrides_;
        std::vector<int> outLengths_;
        std::vector<int> outStrides_;

        AccDataType alpha_;
        AccDataType beta_;

        const InDataType* in_dev_;
        OutDataType* out_dev_;
        IndexDataType* out_indices_dev_;
        AccDataType* workspace_dev_;
        IndexDataType* workspace_indices_dev_;

        InElementwiseOperation in_elementwise_op_;
        AccElementwiseOperation acc_elementwise_op_;

        int invariant_lowest_length;
        int reduce_lowest_length;
        size_t invariant_total_length;
        size_t reduce_total_length;

        index_t blkGroupSize;
        index_t kBlockTileIterations;
        size_t gridSize;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, int nrepeat = 1)
        {
            const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
            const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
                arg.invariant_total_length, arg.blkGroupSize);

            using InGridDesc_M_K    = decltype(in_grid_desc_m_k);
            using WorkspaceDesc_M_K = decltype(ws_desc_m_k);

            using GridwiseReduce =
                GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType, AccDataType,
                                                                     IndexDataType, InGridDesc_M_K,
                                                                     WorkspaceDesc_M_K,
                                                                     ReduceOperation,
                                                                     InElementwiseOperation,
                                                                     AccElementwiseOperation,
                                                                     PropagateNan, BlockSize,
                                                                     MThreadClusterSize,
                                                                     KThreadClusterSize,
                                                                     MThreadSliceSize,
                                                                     KThreadSliceSize,
                                                                     InSrcVectorDim,
                                                                     InSrcVectorSize,
                                                                     OutDstVectorSize>;

            float avg_time = 0;

            const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce, NeedIndices,
                                                                 InDataType, AccDataType,
                                                                 IndexDataType, InGridDesc_M_K,
                                                                 WorkspaceDesc_M_K,
                                                                 InElementwiseOperation,
                                                                 AccElementwiseOperation>;

            avg_time = launch_and_time_kernel(
                kernel, nrepeat, dim3(arg.gridSize), dim3(BlockSize), 0,
                in_grid_desc_m_k, ws_desc_m_k,
                arg.in_elementwise_op_, arg.acc_elementwise_op_,
                arg.blkGroupSize, arg.kBlockTileIterations,
                arg.in_dev_, arg.workspace_dev_, arg.workspace_indices_dev_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if constexpr(OutDstVectorSize != 1)
            return (false);

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return (false);
            }
            else
            {
                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
                    return (false);

                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
                    return (false);
            };
        }
        else
        {
            if(pArg->inStrides_[Rank - 1] != 1)
                return (false);

            if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
                return (false);
        };

        // cases with small reduce_total_length should be handled by the BlockWise method
        if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
            return (false);

        return (true);
    };

    std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        return (std::vector<int>{static_cast<int>(pArg->invariant_total_length),
                                 pArg->blkGroupSize});
    };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<int> inLengths,
                        const std::vector<int> inStrides,
                        const std::vector<int> outLengths,
                        const std::vector<int> outStrides,
                        const std::vector<int> reduceDims,
                        float alpha,
                        float beta,
                        const void* in_dev,
                        void* out_dev,
                        void* out_indices_dev,
                        void* workspace_dev,
                        const InElementwiseOperation in_elementwise_op,
                        const AccElementwiseOperation acc_elementwise_op) override
    {
        return std::make_unique<Argument>(inLengths, inStrides, outLengths, outStrides,
                                          reduceDims, alpha, beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          static_cast<IndexDataType*>(out_indices_dev),
                                          static_cast<AccDataType*>(workspace_dev),
                                          in_elementwise_op, acc_elementwise_op);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
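The workspace sizing of the deleted partial-reduce path can be summarized as a small host-side formula. The sketch below is illustrative only (partial_reduce_workspace_bytes is a made-up name); the 64-byte padding and trailing sizeof(int) follow the GetWorkspaceSizeInBytes implementation shown in the deleted file above.

#include <cstddef>
#include <cstdint>

std::int64_t partial_reduce_workspace_bytes(std::int64_t invariant_total_length,
                                            int blkGroupSize,
                                            std::size_t acc_bytes, // sizeof(AccDataType)
                                            bool need_indices)
{
    // one partial accumulator per (output element, block group)
    const std::int64_t workspace_elems = invariant_total_length * blkGroupSize;
    if(!need_indices)
        return workspace_elems * static_cast<std::int64_t>(acc_bytes);
    // values plus int32 indices, with 64-byte padding between the two buffers
    return workspace_elems * static_cast<std::int64_t>(acc_bytes + sizeof(std::int32_t)) + 64 +
           sizeof(int);
}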
include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp  View file @ a3b4c5cb

@@ -6,6 +6,7 @@
 #include "device.hpp"
 #include "device_reduce.hpp"
 #include "device_reduce_common.hpp"
+#include "gridwise_2d_reduction_multiblock.hpp"
 #include "gridwise_2d_reduction_threadwise.hpp"

 namespace ck {

@@ -19,22 +20,19 @@ template <typename InDataType,
           index_t NumReduceDim,
           typename ReduceOperation,
           typename InElementwiseOperation,
-          typename OutElementwiseOperation,
+          typename AccElementwiseOperation,
           bool PropagateNan,
-          bool NeedIndices,
+          bool OutputIndex,
+          bool HaveIndexInputIfOutputIndex,
           index_t BlockSize,
-          index_t MThreadClusterSize,
-          index_t KThreadClusterSize,
           index_t MThreadSliceSize,
           index_t KThreadSliceSize,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation>
+struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");

-    static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
-                  "Threadwise can only be called with KThreadClusterSize be 1 !");
-
     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                    (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&

@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     using IndexDataType = int32_t;

-    static constexpr bool BetaIsZero = NeedIndices;
+    static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;

     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
     static constexpr bool reduceAllDim = (NumInvariantDim == 0);

-    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
-    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;

-    static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
-                                    const std::vector<int>& inStrides)
+    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
+                                    const std::vector<index_t>& inStrides)
     {
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
         return (in_grid_desc_m_k_padded);
     };

-    static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
-                                    const std::vector<int>& outStrides)
+    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
+                                    const std::vector<index_t>& outStrides)
     {
         const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
         const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int> inLengths,
-                 const std::vector<int> inStrides,
-                 const std::vector<int> outLengths,
-                 const std::vector<int> outStrides,
+        Argument(const std::vector<index_t> inLengths,
+                 const std::vector<index_t> inStrides,
+                 const std::vector<index_t> outLengths,
+                 const std::vector<index_t> outStrides,
                  const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
-                 IndexDataType* out_indices_dev,
-                 AccDataType* workspace_dev,
+                 IndexDataType* out_index_dev,
                  const InElementwiseOperation in_elementwise_op,
-                 const OutElementwiseOperation acc_elementwise_op)
+                 const AccElementwiseOperation acc_elementwise_op)
             : outLengths_{outLengths},
               outStrides_{outStrides},
               in_dev_{in_dev},
               out_dev_{out_dev},
-              out_indices_dev_{out_indices_dev},
+              out_index_dev_{out_index_dev},
               in_elementwise_op_{in_elementwise_op},
              acc_elementwise_op_{acc_elementwise_op}
         {
-            (void)workspace_dev;
-
             inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
             inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

@@ -183,36 +177,39 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
             reduce_lowest_length = inLengths_[Rank - 1];

+            numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+
             gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize;
         }

-        std::vector<int> inLengths_;
-        std::vector<int> inStrides_;
-        std::vector<int> outLengths_;
-        std::vector<int> outStrides_;
+        std::vector<index_t> inLengths_;
+        std::vector<index_t> inStrides_;
+        std::vector<index_t> outLengths_;
+        std::vector<index_t> outStrides_;

         AccDataType alpha_;
         AccDataType beta_;

         const InDataType* in_dev_;
         OutDataType* out_dev_;
-        IndexDataType* out_indices_dev_;
+        IndexDataType* out_index_dev_;

         InElementwiseOperation in_elementwise_op_;
-        OutElementwiseOperation acc_elementwise_op_;
+        AccElementwiseOperation acc_elementwise_op_;

-        int invariant_lowest_length;
-        int reduce_lowest_length;
-        size_t invariant_total_length;
-        size_t reduce_total_length;
+        index_t invariant_lowest_length;
+        index_t reduce_lowest_length;
+        long_index_t invariant_total_length;
+        long_index_t reduce_total_length;

+        int numBlockTileIteration;
         size_t gridSize;
     };

     struct Invoker : public BaseInvoker
     {
-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             const auto in_grid_desc_m_k =
                 DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);

@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
             using InGridDesc_M_K = decltype(in_grid_desc_m_k);
             using OutGridDesc_M  = decltype(out_grid_desc_m);

-            using GridwiseReduce =
-                GridwiseReduction_mk_to_m_threadwise<InDataType, OutDataType, AccDataType,
-                                                     IndexDataType, InGridDesc_M_K, OutGridDesc_M,
-                                                     ReduceOperation, InElementwiseOperation,
-                                                     OutElementwiseOperation, PropagateNan,
-                                                     BetaIsZero, BlockSize, MThreadClusterSize,
-                                                     KThreadClusterSize, MThreadSliceSize,
-                                                     KThreadSliceSize, InSrcVectorDim,
-                                                     InSrcVectorSize, OutDstVectorSize>;
-
-            float avg_time = 0;
+            using GridwiseReduce =
+                GridwiseReduction_mk_to_m_threadwise<InDataType, OutDataType, AccDataType,
+                                                     IndexDataType, InGridDesc_M_K, OutGridDesc_M,
+                                                     ReduceOperation, InElementwiseOperation,
+                                                     AccElementwiseOperation,
+                                                     InMemoryDataOperationEnum::Set, PropagateNan,
+                                                     BlockSize, MThreadSliceSize, KThreadSliceSize,
+                                                     InSrcVectorDim, InSrcVectorSize,
+                                                     OutDstVectorSize>;
+
+            float avg_time = 0;

             const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
-                                                         NeedIndices,
+                                                         OutputIndex,
+                                                         HaveIndexInput,
                                                          InDataType,
                                                          OutDataType,
                                                          AccDataType,

@@ -252,10 +249,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
                                                          InGridDesc_M_K,
                                                          OutGridDesc_M,
                                                          InElementwiseOperation,
-                                                         OutElementwiseOperation>;
+                                                         AccElementwiseOperation>;

-            avg_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
+            avg_time = launch_and_time_kernel(stream_config,
+                                              kernel,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,

@@ -265,16 +262,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
                                               arg.acc_elementwise_op_,
                                               arg.alpha_,
                                               arg.in_dev_,
+                                              nullptr,
                                               arg.beta_,
                                               arg.out_dev_,
-                                              arg.out_indices_dev_);
+                                              arg.out_index_dev_);

             return (avg_time);
         };

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         };
     };

@@ -310,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
         if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
             return (false);

-        // TODO: remove this. Should return true, as long as this DeviceOP instance support this
-        // case for bigger reduce_total_length size, we are supposed to use BlockWise method for
-        // better performance
+        // cases with big reduce_total_length should be handled by Blockwise kernel
         if(pArg->reduce_total_length / KThreadSliceSize >= 32)
             return (false);

@@ -320,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int> inLengths,
-                        const std::vector<int> inStrides,
-                        const std::vector<int> outLengths,
-                        const std::vector<int> outStrides,
+    MakeArgumentPointer(const std::vector<index_t> inLengths,
+                        const std::vector<index_t> inStrides,
+                        const std::vector<index_t> outLengths,
+                        const std::vector<index_t> outStrides,
                         const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
+                        const void* in_index_dev,
                         void* out_dev,
-                        void* out_indices_dev,
-                        void* workspace_dev,
+                        void* out_index_dev,
                         const InElementwiseOperation in_elementwise_op,
-                        const OutElementwiseOperation acc_elementwise_op) override
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
+        (void)in_index_dev;
+
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           outLengths,

@@ -343,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
                                           static_cast<OutDataType*>(out_dev),
-                                          static_cast<IndexDataType*>(out_indices_dev),
-                                          static_cast<AccDataType*>(workspace_dev),
+                                          static_cast<IndexDataType*>(out_index_dev),
                                           in_elementwise_op,
                                           acc_elementwise_op);
     };

@@ -359,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceReducceThreadWise<" << BlockSize << ",";
-        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
-        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
+        str << "DeviceReduceThreadWise<" << BlockSize << ",";
+        str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
+        str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
         str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
         // clang-format on
 ...
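With MThreadClusterSize and KThreadClusterSize dropped from the threadwise kernel above, grid sizing reduces to a ceiling division over the invariant length. The sketch below restates it in plain C++ under that assumption; threadwise_grid_size is an illustrative name, not part of the library.

#include <cstdint>

std::int64_t threadwise_grid_size(std::int64_t invariant_total_length,
                                  int block_size,
                                  int m_thread_slice_size)
{
    // one block covers BlockSize * MThreadSliceSize output rows
    const std::int64_t m_block_tile_size =
        static_cast<std::int64_t>(block_size) * m_thread_slice_size;
    // round invariant_total_length up to a whole number of block tiles
    return (invariant_total_length + m_block_tile_size - 1) / m_block_tile_size;
}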
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp  0 → 100644  View file @ a3b4c5cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include "data_type.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
binary_element_wise
{
template
<
typename
Y
,
typename
X1
,
typename
X2
>
struct
Add
;
template
<
>
struct
Add
<
double
,
double
,
double
>
{
__host__
__device__
constexpr
void
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
{
dst
=
src1
+
src2
;
}
};
template
<
>
struct
Add
<
float
,
float
,
float
>
{
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
dst
=
src1
+
src2
;
}
};
template
<
>
struct
Add
<
half_t
,
half_t
,
half_t
>
{
__host__
__device__
constexpr
void
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
{
dst
=
src1
+
src2
;
}
};
template
<
>
struct
Add
<
bhalf_t
,
bhalf_t
,
bhalf_t
>
{
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
const
float
x1
=
ck
::
type_convert
<
float
>
(
src1
);
const
float
x2
=
ck
::
type_convert
<
float
>
(
src2
);
const
float
y
=
x1
+
x2
;
dst
=
ck
::
type_convert
<
bhalf_t
>
(
y
);
}
};
template
<
typename
Y
,
typename
X1
,
typename
X2
>
struct
Substract
;
template
<
>
struct
Substract
<
double
,
double
,
double
>
{
__host__
__device__
constexpr
void
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
{
dst
=
src1
-
src2
;
}
};
template
<
>
struct
Substract
<
float
,
float
,
float
>
{
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
dst
=
src1
-
src2
;
}
};
template
<
>
struct
Substract
<
half_t
,
half_t
,
half_t
>
{
__host__
__device__
constexpr
void
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
{
dst
=
src1
-
src2
;
}
};
template
<
>
struct
Substract
<
bhalf_t
,
bhalf_t
,
bhalf_t
>
{
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
const
float
x1
=
ck
::
type_convert
<
float
>
(
src1
);
const
float
x2
=
ck
::
type_convert
<
float
>
(
src2
);
const
float
y
=
x1
-
x2
;
dst
=
ck
::
type_convert
<
bhalf_t
>
(
y
);
}
};
}
// namespace binary_element_wise
}
// namespace tensor_operation
}
// namespace ck
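A minimal host-side sketch (not part of the commit) of how these binary functors are applied per element. Compilation with hipcc and this repository's include paths is assumed so that data_type.hpp and the __host__ __device__ qualifiers resolve; the buffers and sizes are invented for illustration.

// Sketch only: element-wise application of Add<float, float, float> on the host.
#include <cstdio>
#include <vector>
#include "binary_element_wise_operation.hpp"

int main()
{
    ck::tensor_operation::binary_element_wise::Add<float, float, float> add_op;

    std::vector<float> x1{1.f, 2.f, 3.f};
    std::vector<float> x2{10.f, 20.f, 30.f};
    std::vector<float> y(x1.size());

    // A device kernel would call the functor the same way, one element at a time.
    for(std::size_t i = 0; i < y.size(); ++i)
        add_op(y[i], x1[i], x2[i]);

    std::printf("%.1f %.1f %.1f\n", y[0], y[1], y[2]); // 11.0 22.0 33.0
}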
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
 #pragma once

 #include "data_type.hpp"
+#include "math_v2.hpp"

 namespace ck {
 namespace tensor_operation {
...
@@ -143,35 +144,22 @@ struct AddHardswishAdd
     }
 };

-struct RequantReluRequant
+struct Normalize
 {
-    // FIXME: We just need one scale for Relu / Leaky Relu / PRelu
-    RequantReluRequant(float scaleGemm, float scaleRelu)
-        : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
-    {
-    }
-
-    __host__ __device__ constexpr void operator()(int8_t& y, const int& x) const
-    {
-        float gemm_requant = scaleGemm_ * static_cast<float>(x);
-        float relu         = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = static_cast<int8_t>(relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant);
-    }
-
-    // for reference_gemm
-    __host__ __device__ constexpr void operator()(float& y, const float& x) const
-    {
-        float gemm_requant = scaleGemm_ * x;
-        float relu         = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = static_cast<float>(relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant);
-    }
-
-    float scaleGemm_;
-    float scaleRelu_;
+    Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
+
+    __host__ __device__ constexpr void operator()(float& y,
+                                                  const float& x,
+                                                  const float& mean,
+                                                  const float& mean_square,
+                                                  const float& gamma,
+                                                  const float& beta) const
+    {
+        float variance = mean_square - (mean * mean);
+        y              = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
+    }
+
+    float epsilon_;
 };

 // Unary operators are usually called element-wisely before/after the reduction is executed on the
...
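As a sanity check of the Normalize formula added above, here is a tiny standalone program (the input numbers are invented for illustration) that reproduces y = ((x - mean) / sqrt(variance + eps)) * gamma + beta on the host.

// Illustrative check of the Normalize math; values are made up for the example.
#include <cmath>
#include <cstdio>

int main()
{
    const float x = 3.f, mean = 2.f, mean_square = 5.f; // E[x] and E[x^2] of the row
    const float gamma = 1.f, beta = 0.f, epsilon = 1e-4f;

    const float variance = mean_square - mean * mean; // E[x^2] - E[x]^2 = 1
    const float y = ((x - mean) / std::sqrt(variance + epsilon)) * gamma + beta;

    std::printf("%f\n", y); // ~1.0: x sits one standard deviation above the mean
}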
@@ -309,7 +297,7 @@ struct UnaryAbs<float, float>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -317,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
+    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -325,7 +313,7 @@ struct UnaryAbs<double, double>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -333,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const
-    {
-        int8_t sgn = x >> (8 - 1);
-
-        y = (x ^ sgn) - sgn;
-    };
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
 };

 template <typename Y, typename X>
...
@@ -349,7 +332,7 @@ struct UnarySqrt<float, float>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
 };

 template <>
@@ -357,7 +340,10 @@ struct UnarySqrt<double, double>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::sqrt(x); };
 };

 } // namespace element_wise
...
include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp
...
@@ -5,20 +5,6 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {

-struct ReduceSum
-{
-    __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
-
-    __host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
-};
-
-struct ReduceSquareSum
-{
-    __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
-
-    __host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
-};
-
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
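The removed ReduceSum and ReduceSquareSum functors expressed the two accumulations (acc += v and acc += v * v) whose averages become the mean and mean_square inputs that the new Normalize operator consumes. A host-side sketch of that accumulation pattern, with invented data; the loop itself is not code from this commit.

// Sketch only: computing mean and mean_square the way the removed functors accumulate.
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> row{1.f, 2.f, 3.f, 4.f};

    float sum = 0.f, square_sum = 0.f;
    for(float v : row)
    {
        sum += v;        // ReduceSum::Reduce
        square_sum += v * v; // ReduceSquareSum::Reduce
    }

    const float mean        = sum / row.size();        // E[x]
    const float mean_square = square_sum / row.size(); // E[x^2], fed to Normalize

    std::printf("mean=%f mean_square=%f\n", mean, mean_square);
}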
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
0 → 100644
#ifndef UTILITY_BLOCK_TO_CTILE_MAP
#define UTILITY_BLOCK_TO_CTILE_MAP

#include "utility/math.hpp"
#include "utility/number.hpp"
#include "tensor_description/tensor_adaptor.hpp"
#include "tensor_description/multi_index_transform_helper.hpp"

namespace ck {

// Rows of column-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N0_M01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_M00_N0_M01() = default;

    __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                   index_t M01 = 1)
        : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);

        const index_t grid_size = M00 * M01_ * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);

        if(M0 % M01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);

        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_insert_transform(1),
                       make_unmerge_transform(make_tuple(M00, M01)),
                       make_pass_through_transform(make_tuple(N0))),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

        const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_n0_m01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
    UnderlyingMap underlying_map_;
};

// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t M01 = 8)
        : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const index_t grid_size = M0 * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }

    private:
    index_t M01_;
    CGridDesc_M_N c_grid_desc_m_n_;
};

// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;

    __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(
        const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8, index_t KSplit = 1)
        : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const index_t grid_size = M0 * N0 * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        const index_t idx_ksplit = block_1d_id / (M0 * N0);
        block_1d_id              = block_1d_id % (M0 * N0);

        index_t idx_N0 = block_1d_id % N0;
        index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return make_tuple(idx_ksplit,
                          idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }

    private:
    index_t M01_;
    index_t KSplit_;
    CGridDesc_M_N c_grid_desc_m_n_;
};

// Blocks of row-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;

    __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t M01 = 1,
                                                        index_t N01 = 1)
        : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ __device__ static constexpr auto
    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    index_t M01_, N01_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
    UnderlyingMap underlying_map_;
};

// 2D slices of row-vectors in 3D space
template <index_t MPerBlock,
          index_t NPerBlock,
          typename CGridDesc_M_N,
          bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_KSplit_M00_N00_M01_N01
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default;

    __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
                                                    index_t M01    = 1,
                                                    index_t N01    = 1,
                                                    index_t KSplit = 1)
        : M01_(M01),
          N01_(N01),
          KSplit_(KSplit),
          underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
    {
    }

    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01_);
        const auto N00 = math::integer_divide_ceil(N0, N01_);

        const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        return underlying_map_.CalculateBottomIndex(idx_top);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                             const CTileDim& c_tile_dim) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
        else
            return true;
    }

    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        if constexpr(DeviceCTileIndexCheck)
            return true; // validity check moved to kernel

        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        if(M0 % M01_ == 0 && N0 % N01_ == 0)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    private:
    __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
                                                      index_t M01,
                                                      index_t N01,
                                                      index_t KSplit)
    {
        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const auto M00 = math::integer_divide_ceil(M0, M01);
        const auto N00 = math::integer_divide_ceil(N0, N01);

        const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KSplit),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
    }

    index_t M01_, N01_, KSplit_;
    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
    UnderlyingMap underlying_map_;
};

template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
                                                const CTileDim& c_tile_dim)
{
    bool is_valid = false;

    const index_t m_block = c_tile_dim[Number<0>{}];
    const index_t n_block = c_tile_dim[Number<1>{}];

    if constexpr(CTileIdx::Size() == 2)
    {
        const index_t m_block_idx = c_tile_idx[Number<0>{}];
        const index_t n_block_idx = c_tile_idx[Number<1>{}];

        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }
    }
    else if constexpr(CTileIdx::Size() == 3)
    {
        const index_t ksplit_idx  = c_tile_idx[Number<0>{}];
        const index_t m_block_idx = c_tile_idx[Number<1>{}];
        const index_t n_block_idx = c_tile_idx[Number<2>{}];

        if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
        {
            is_valid = true;
        }

        ignore = ksplit_idx;
    }

    return is_valid;
}

} // namespace ck
#endif // UTILITY_BLOCK_TO_CTILE_MAP
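To make the M01-adaptive mapping above concrete, the following standalone snippet repeats the integer arithmetic of BlockToCTileMap_M00_N0_M01Adapt::CalculateBottomIndex for a small, invented grid (M0 = 5 row tiles, N0 = 3 column tiles, M01 = 2) and prints which (m, n) C tile each 1-D block id lands on; the last, partial group of row tiles automatically falls back to a smaller M01.

// Same arithmetic as CalculateBottomIndex above, lifted out for illustration only.
#include <cstdio>

int main()
{
    const int M0 = 5, N0 = 3, M01 = 2; // invented tile counts for a small C matrix

    for(int block_1d_id = 0; block_1d_id < M0 * N0; ++block_1d_id)
    {
        int idx_N0 = block_1d_id % N0;
        int idx_M0 = block_1d_id / N0;

        // The last (possibly partial) group of M01 row tiles uses a smaller M01.
        const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

        int idx_M00          = idx_M0 / M01;
        int idx_M01          = idx_M0 % M01;
        int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        int m = idx_N0_M01_local % M01_adapt + idx_M00 * M01;
        int n = idx_N0_M01_local / M01_adapt;

        std::printf("block %2d -> C tile (m=%d, n=%d)\n", block_1d_id, m, n);
    }
}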
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp
deleted 100644 → 0
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseReduction
,
bool
NeedIndices
,
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
IndexDataType
,
typename
InGridDesc_M_K
,
typename
OutGridDesc_M
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
>
__global__
void
kernel_reduce_blockwise
(
const
InGridDesc_M_K
in_grid_desc_m_k
,
const
OutGridDesc_M
out_grid_desc_m
,
const
InElementwiseOperation
in_elementwise_op
,
const
OutElementwiseOperation
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
if
constexpr
(
!
NeedIndices
)
{
constexpr
bool
IsSecondCall
=
false
;
GridwiseReduction
::
template
Run
<
IsSecondCall
>(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
alpha
,
p_in_global
,
beta
,
p_out_global
,
p_ws_indices_global
,
p_indices_global
);
}
else
{
GridwiseReduction
::
RunWithIndex
(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
alpha
,
p_in_global
,
beta
,
p_out_global
,
p_ws_indices_global
,
p_indices_global
);
};
};
template
<
typename
GridwiseReduction
,
bool
NeedIndices
,
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
IndexDataType
,
typename
InGridDesc_M_K
,
typename
OutGridDesc_M
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
>
__global__
void
kernel_reduce_blockwise_second_call
(
const
InGridDesc_M_K
in_grid_desc_m_k
,
const
OutGridDesc_M
out_grid_desc_m
,
const
InElementwiseOperation
in_elementwise_op
,
const
OutElementwiseOperation
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
if
constexpr
(
!
NeedIndices
)
{
constexpr
bool
IsSecondCall
=
true
;
GridwiseReduction
::
template
Run
<
IsSecondCall
>(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
alpha
,
p_in_global
,
beta
,
p_out_global
,
p_ws_indices_global
,
p_indices_global
);
}
else
{
GridwiseReduction
::
RunSecondCallWithIndex
(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
alpha
,
p_in_global
,
beta
,
p_out_global
,
p_ws_indices_global
,
p_indices_global
);
};
};
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
IndexDataType
,
typename
InGridDesc_M_K
,
typename
OutGridDesc_M
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
,
bool
PropagateNan
,
bool
BetaIsZero
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
GridwiseReduction_mk_to_m_blockwise
{
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
(
MThreadSliceSize
%
OutDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
InSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
template
<
bool
IsSecondCall
>
__device__
static
void
Run
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
OutElementwiseOperation
&
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
if
constexpr
(
IsSecondCall
)
{
static_assert
(
InSrcVectorDim
==
1
,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!"
);
};
using
BlockwiseReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ReduceOperation
,
PropagateNan
>
;
(
void
)
p_ws_indices_global
;
(
void
)
p_indices_global
;
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zeroVal
;
});
const
auto
toReduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_1d_id
=
get_block_1d_id
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
index_t
toReduceTiles
=
(
toReduceLength
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
index_t
reducedTiles
=
0
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// do element-wise pre-reduction operation
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in_elementwise_op
(
in_thread_buf
(
Number
<
offset
>
{}),
in_thread_buf
(
Number
<
offset
>
{}));
});
});
ThreadwiseReduce
::
Reduce
(
in_thread_buf
,
accu_value_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
toReduceTiles
);
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
static_for
<
0
,
MThreadSliceSize
,
1
>
{}(
[
&
](
auto
I
)
{
BlockwiseReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I
));
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
(
thread_k_cluster_id
==
0
)
{
acc_elementwise_op
(
accu_value_buf
(
I
),
accu_value_buf
(
I
));
accu_value_buf
(
I
)
*=
alpha
;
}
});
if
(
thread_k_cluster_id
==
0
)
{
if
constexpr
(
!
BetaIsZero
)
{
if
(
!
float_equal_zero
{}(
beta
))
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
OutDataType
,
OutGridDesc_M
,
decltype
(
reduced_data_desc
),
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
out_global_buf
,
reduced_data_desc
,
make_tuple
(
I0
),
priorDstValueBuf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValueBuf
[
I
])
*
beta
;
});
};
};
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
OutDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_dst_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
out_global_buf
);
}
};
__device__
static
void
RunWithIndex
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
OutElementwiseOperation
&
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
using
BlockwiseReduceWithIndex
=
PartitionedBlockwiseReductionWithIndex
<
AccDataType
,
IndexDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
AccumulationWithIndex
=
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
,
IndexDataType
>
;
(
void
)
p_ws_indices_global
;
// LDS
__shared__
AccDataType
p_reduce_work_val_buffer
[
BlockSize
];
__shared__
IndexDataType
p_reduce_work_idx_buffer
[
BlockSize
];
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
auto
out_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
out_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_indices_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
reduce_work_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_val_buffer
,
BlockSize
);
auto
reduce_work_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_idx_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_val_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_idx_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
,
true
>
accu_index_buf
;
const
auto
toReduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_1d_id
=
get_block_1d_id
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
index_t
indexOffset
=
0
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zeroVal
;
accu_index_buf
(
I
)
=
0
;
});
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
index_t
toReduceTiles
=
(
toReduceLength
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
index_t
reducedTiles
=
0
;
do
{
// load the thread slice
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf
(
Number
<
offset
>
{})
=
indexOffset
+
thread_k_cluster_id
*
KThreadSliceSize
+
iK
();
// do element-wise pre-reduction operation
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
in_thread_val_buf
(
Number
<
offset
>
{}));
});
AccDataType
tmpValue
=
zeroVal
;
IndexDataType
tmpIndex
=
0
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
AccumulationWithIndex
::
Calculate
(
tmpValue
,
in_thread_val_buf
[
Number
<
offset
>
{}],
tmpIndex
,
in_thread_idx_buf
[
Number
<
offset
>
{}]);
});
BlockwiseReduceWithIndex
::
Reduce
(
reduce_work_val_buf
,
reduce_work_idx_buf
,
tmpValue
,
tmpIndex
);
AccumulationWithIndex
::
Calculate
(
accu_value_buf
(
iM
),
tmpValue
,
accu_index_buf
(
iM
),
tmpIndex
);
});
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
indexOffset
+=
K_BlockTileSize
;
reducedTiles
++
;
}
while
(
reducedTiles
<
toReduceTiles
);
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
(
thread_k_cluster_id
==
0
)
{
// for indiced operation, acc_elementwise_op shoud do nothing
acc_elementwise_op
(
accu_value_buf
(
I
),
accu_value_buf
(
I
));
accu_value_buf
(
I
)
*=
alpha
;
}
});
if
(
thread_k_cluster_id
==
0
)
{
if
constexpr
(
!
BetaIsZero
)
{
if
(
!
float_equal_zero
{}(
beta
))
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
OutDataType
,
OutGridDesc_M
,
decltype
(
reduced_data_desc
),
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
out_global_val_buf
,
reduced_data_desc
,
make_tuple
(
I0
),
priorDstValueBuf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValueBuf
[
I
])
*
beta
;
});
};
};
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
OutDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_dst_idx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
IndexDataType
,
IndexDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_dst_val_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
out_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_index_buf
,
out_grid_desc_m
,
out_global_idx_buf
);
}
};
__device__
static
void
RunSecondCallWithIndex
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
InElementwiseOperation
in_elementwise_op
,
const
OutElementwiseOperation
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_ws_values_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
static_assert
(
InSrcVectorDim
==
1
,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!"
);
using
BlockwiseReduceWithIndex
=
PartitionedBlockwiseReductionWithIndex
<
AccDataType
,
IndexDataType
,
BlockSize
,
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
AccumulationWithIndex
=
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
,
IndexDataType
>
;
(
void
)
in_elementwise_op
;
// LDS
__shared__
AccDataType
p_reduce_work_val_buffer
[
BlockSize
];
__shared__
IndexDataType
p_reduce_work_idx_buffer
[
BlockSize
];
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ws_values_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
const
auto
src_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ws_indices_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
());
auto
out_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
out_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_indices_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
reduce_work_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_val_buffer
,
BlockSize
);
auto
reduce_work_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_idx_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_val_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_idx_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
,
true
>
accu_index_buf
;
const
auto
toReduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_1d_id
=
get_block_1d_id
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_val_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_src_idx_load
=
ThreadwiseTensorSliceTransfer_v2
<
IndexDataType
,
IndexDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zeroVal
;
accu_index_buf
(
I
)
=
0
;
});
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
index_t
toReduceTiles
=
(
toReduceLength
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
index_t
reducedTiles
=
0
;
do
{
// load the thread slice
threadwise_src_val_load
.
Run
(
in_grid_desc_m_k
,
src_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
threadwise_src_idx_load
.
Run
(
in_grid_desc_m_k
,
src_global_idx_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_idx_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
tmpValue
=
zeroVal
;
IndexDataType
tmpIndex
=
0
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
AccumulationWithIndex
::
Calculate
(
tmpValue
,
in_thread_val_buf
[
Number
<
offset
>
{}],
tmpIndex
,
in_thread_idx_buf
[
Number
<
offset
>
{}]);
});
BlockwiseReduceWithIndex
::
Reduce
(
reduce_work_val_buf
,
reduce_work_idx_buf
,
tmpValue
,
tmpIndex
);
AccumulationWithIndex
::
Calculate
(
accu_value_buf
(
iM
),
tmpValue
,
accu_index_buf
(
iM
),
tmpIndex
);
});
threadwise_src_val_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
threadwise_src_idx_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
toReduceTiles
);
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
(
thread_k_cluster_id
==
0
)
{
// for indiced operation, acc_elementwise_op shoud do nothing
acc_elementwise_op
(
accu_value_buf
(
I
),
accu_value_buf
(
I
));
accu_value_buf
(
I
)
*=
alpha
;
}
});
if
(
thread_k_cluster_id
==
0
)
{
if
constexpr
(
!
BetaIsZero
)
{
if
(
!
float_equal_zero
{}(
beta
))
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
OutDataType
,
OutGridDesc_M
,
decltype
(
reduced_data_desc
),
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
1
,
true
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
out_global_val_buf
,
reduced_data_desc
,
make_tuple
(
I0
),
priorDstValueBuf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValueBuf
[
I
])
*
beta
;
});
};
};
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
OutDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_dst_idx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
IndexDataType
,
IndexDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_dst_val_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
out_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_index_buf
,
out_grid_desc_m
,
out_global_idx_buf
);
}
};
};
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_partial_reduce.hpp → include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp
...
@@ -23,75 +23,86 @@
...
@@ -23,75 +23,86 @@
* SOFTWARE.
* SOFTWARE.
*
*
*******************************************************************************/
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_
PARTIAL_REDUCE_
HPP
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_
PARTIAL_REDUCE_
HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseReduction
,
template
<
typename
GridwiseReduction
,
bool
NeedIndices
,
bool
OutputIndex
,
bool
HaveIndexInput
,
typename
InDataType
,
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
typename
InGridDesc_M_K
,
typename
InGridDesc_M_K
,
typename
Workspace
Desc_M
_K
,
typename
OutGrid
Desc_M
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
>
typename
AccElementwiseOperation
>
__global__
void
__global__
void
kernel_reduce_multiblock
(
const
InGridDesc_M_K
in_grid_desc_m_k
,
kernel_partial_reduce_multiblock
(
const
InGridDesc_M_K
in_grid_desc_m_k
,
const
OutGridDesc_M
out_grid_desc_m
,
const
WorkspaceDesc_M_K
workspace_desc_m_k
,
const
InElementwiseOperation
in_elementwise_op
,
const
InElementwiseOperation
in_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
,
index_t
block_group_size
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_src_global
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
AccDataType
*
const
__restrict__
p_ws_values_global
,
const
IndexDataType
*
const
__restrict__
p_in_index_global
,
IndexDataType
*
const
__restrict__
p_ws_indices_global
)
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
,
IndexDataType
*
const
__restrict__
p_out_index_global
)
{
{
if
constexpr
(
!
NeedIndices
)
if
constexpr
(
!
OutputIndex
)
{
{
(
void
)
p_in_index_global
;
(
void
)
p_out_index_global
;
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
workspace
_desc_m
_k
,
out_grid
_desc_m
,
in_elementwise_op
,
in_elementwise_op
,
acc_elementwise_op
,
acc_elementwise_op
,
block_group_size
,
block_group_size
,
num_k_block_tile_iteration
,
num_k_block_tile_iteration
,
p_src_global
,
alpha
,
p_ws_values_global
,
p_in_value_global
,
p_ws_indices_global
);
beta
,
p_out_value_global
);
}
}
else
else
{
{
GridwiseReduction
::
RunWithIndex
(
in_grid_desc_m_k
,
GridwiseReduction
::
template
RunWithIndex
<
HaveIndexInput
>(
in_grid_desc_m_k
,
workspace_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
in_elementwise_op
,
acc_elementwise_op
,
acc_elementwise_op
,
block_group_size
,
num_k_block_tile_iteration
,
num_k_block_tile_iteration
,
alpha
,
p_src_global
,
p_in_value_global
,
p_ws_values_global
,
p_in_index_global
,
p_ws_indices_global
);
beta
,
p_out_value_global
,
p_out_index_global
);
};
};
};
};
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
typename
InGridDesc_M_K
,
typename
InGridDesc_M_K
,
typename
Workspace
Desc_M
_K
,
typename
OutGrid
Desc_M
,
typename
ReduceOperation
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
AccElementwiseOperation
,
InMemoryDataOperationEnum
OutMemoryDataOperation
,
bool
PropagateNan
,
bool
PropagateNan
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
MThreadClusterSize
,
...
@@ -101,14 +112,13 @@ template <typename InDataType,
...
@@ -101,14 +112,13 @@ template <typename InDataType,
index_t
InSrcVectorDim
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
index_t
OutDstVectorSize
>
struct
GridwiseReduction_mk_to_m
k
_multiblock
_partial_reduce
struct
GridwiseReduction_mk_to_m_multiblock
{
{
static_assert
((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
),
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
(
MThreadSliceSize
%
OutDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
OutDstVectorSize
==
1
,
"OutDstVectorSize must be 1 for MultiBlockPartialReduce!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
InSrcVectorDim
==
0
);
static
constexpr
bool
reorder_thread_cluster
=
(
InSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
...
@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
using
ThreadReduceDstDesc_M
=
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
BlockwiseReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ReduceOperation
,
PropagateNan
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
using
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
__device__
static
void
Run
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
__device__
static
void
Run
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
WorkspaceDesc_M_K
&
workspace
_desc_m
_k
,
const
OutGridDesc_M
&
out_grid
_desc_m
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
index_t
block_group_size
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
const
InDataType
*
const
__restrict__
p_src_global
,
AccDataType
alpha
,
AccDataType
*
const
__restrict__
p_ws_values_global
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
IndexDataType
*
const
__restrict__
p_ws_indices_global
)
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
)
{
{
using
BlockwiseReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
const
auto
identityVal
=
ReduceOperation
::
GetIdentityValue
();
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ReduceOperation
,
PropagateNan
>
;
(
void
)
p_ws_indices_global
;
(
void
)
acc_elementwise_op
;
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
// LDS
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
const
auto
in_global_buf
=
const
auto
in_global_
val_
buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
src
_global
,
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
in_value
_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zero
Val
));
type_convert
<
InDataType
>
(
identity
Val
));
auto
workspace
_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
out
_global_
val_
buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
ws
_value
s
_global
,
workspace
_desc_m
_k
.
GetElementSpaceSize
());
p_
out
_value_global
,
out_grid
_desc_m
.
GetElementSpaceSize
());
auto
reduce_work_buf
=
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
...
@@ -181,7 +191,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -181,7 +191,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zero
Val
;
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
identity
Val
;
});
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
...
@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
do
do
{
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_buf
,
in_global_
val_
buf
,
thread_buffer_desc
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
in_thread_buf
);
in_thread_buf
);
...
@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles
++
;
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
// Each block executes multiple parallel reductions on the LDS, and due to the using of
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
// vector_load, each block/thread is involved into multiple invarirant dimensions.
static_for
<
0
,
MThreadSliceSize
,
1
>
{}(
static_for
<
0
,
MThreadSliceSize
,
1
>
{}(
[
&
](
auto
I
)
{
BlockwiseReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I
));
});
[
&
](
auto
I
)
{
BlockwiseReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I
));
});
constexpr
auto
reduced_data_desc
=
make_naive_tensor_descriptor_packed
(
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
if
(
thread_k_cluster_id
==
0
)
{
acc_elementwise_op
(
accu_value_buf
(
I
),
accu_value_buf
(
I
));
accu_value_buf
(
I
)
*=
alpha
;
}
});
if
(
thread_k_cluster_id
==
0
)
if
(
thread_k_cluster_id
==
0
)
{
{
auto
threadwise_workspace_store
=
if
(
block_group_size
==
0
&&
!
float_equal_zero
{}(
beta
))
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
OutDataType
,
OutGridDesc_M
,
decltype
(
reduced_data_desc
),
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
out_global_val_buf
,
reduced_data_desc
,
make_tuple
(
I0
),
priorDstValueBuf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValueBuf
[
I
])
*
beta
;
});
};
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
Acc
DataType
,
Out
DataType
,
decltype
(
reduced_data_desc
),
decltype
(
reduced_data_desc
),
Workspace
Desc_M
_K
,
OutGrid
Desc_M
,
PassThroughOp
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
,
1
>
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
>
,
1
,
0
,
1
,
OutDstVectorSize
,
In
MemoryDataOperation
Enum
::
Set
,
Out
MemoryDataOperation
,
1
,
1
,
true
>
(
true
>
(
workspace
_desc_m
_k
,
out_grid
_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
),
block_local_id
),
PassThroughOp
{});
PassThroughOp
{});
threadwise_
workspace
_store
.
Run
(
reduced_data_desc
,
threadwise_
dst
_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
),
accu_value_buf
,
accu_value_buf
,
workspace
_desc_m
_k
,
out_grid
_desc_m
,
workspace
_global_buf
);
out
_global_
val_
buf
);
}
}
};
};
template
<
bool
HaveIndexInput
>
__device__
static
void
RunWithIndex
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
__device__
static
void
RunWithIndex
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
WorkspaceDesc_M_K
&
workspace_desc_m_k
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
InElementwiseOperation
in_elementwise_op
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
const
InDataType
*
const
__restrict__
p_src_global
,
AccDataType
alpha
,
AccDataType
*
const
__restrict__
p_ws_values_global
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
IndexDataType
*
const
__restrict__
p_ws_indices_global
)
const
IndexDataType
*
const
__restrict__
p_in_index_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
,
IndexDataType
*
const
__restrict__
p_out_index_global
)
{
{
using
BlockwiseReduceWithIndex
=
using
BlockwiseReduceWithIndex
=
PartitionedBlockwiseReductionWithIndex
<
AccDataType
,
PartitionedBlockwiseReductionWithIndex
<
AccDataType
,
IndexDataType
,
IndexDataType
,
BlockSize
,
BlockSize
,
ThreadCluster
Lengths_M_K
,
Sequence
<
M
ThreadCluster
Size
,
KThreadClusterSize
>
,
ThreadClusterArrangeOrder
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
ReduceOperation
,
PropagateNan
>
;
PropagateNan
>
;
...
@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType
,
AccDataType
,
IndexDataType
>
;
IndexDataType
>
;
(
void
)
acc_elementwise_op
;
(
void
)
in_elementwise_op
;
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
// LDS
// LDS
__shared__
AccDataType
p_reduce_work_val_buffer
[
BlockSize
];
__shared__
AccDataType
p_reduce_work_val_buffer
[
BlockSize
];
__shared__
i
ndex
_t
p_reduce_work_idx_buffer
[
BlockSize
];
__shared__
I
ndex
DataType
p_reduce_work_idx_buffer
[
BlockSize
];
const
auto
in_global_buf
=
const
auto
identityVal
=
ReduceOperation
::
GetIdentityValue
();
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_src_global
,
const
auto
in_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
type_convert
<
InDataType
>
(
identityVal
));
auto
workspace_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
in_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ws_values_global
,
workspace_desc_m_k
.
GetElementSpaceSize
());
p_in_index_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
());
auto
workspace_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
out_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ws_indices_global
,
workspace_desc_m_k
.
GetElementSpaceSize
());
p_out_value_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
out_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_index_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
reduce_work_val_buf
=
auto
reduce_work_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_val_buffer
,
BlockSize
);
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_val_buffer
,
BlockSize
);
...
@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_val_buf
;
in_thread_val_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
IndexDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
MThreadSliceSize
*
KThreadSliceSize
,
...
@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
,
true
>
accu_index_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
,
true
>
accu_index_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_global_1d_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
block_group_size
;
const
index_t
block_local_id
=
block_global_id
%
block_group_size
;
const
auto
thread_cluster_idx
=
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
...
@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
...
@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
auto
threadwise_src_val_load
=
AccDataType
,
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
InGridDesc_M_K
,
AccDataType
,
decltype
(
thread_buffer_desc
),
InGridDesc_M_K
,
ThreadBufferLengths
,
decltype
(
thread_buffer_desc
),
ThreadBufferDimAccessOrder
,
ThreadBufferLengths
,
InSrcVectorDim
,
ThreadBufferDimAccessOrder
,
InSrcVectorSize
,
InSrcVectorDim
,
1
,
InSrcVectorSize
,
false
>
(
1
,
in_grid_desc_m_k
,
false
>
(
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
in_grid_desc_m_k
,
block_local_id
*
reduceSizePerBlock
+
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
thread_k_cluster_id
*
KThreadSliceSize
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
index_t
indexOffset
=
block_local_id
*
reduceSizePerBlock
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zero
Val
;
accu_value_buf
(
I
)
=
identity
Val
;
accu_index_buf
(
I
)
=
0
;
accu_index_buf
(
I
)
=
0
;
});
});
index_t
reducedTiles
=
0
;
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
do
{
// load the thread slice
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// initialize the indices for the per-thread to-reduce values
index_t
reducedTiles
=
0
;
in_thread_idx_buf
(
Number
<
offset
>
{})
=
indexOffset
+
thread_k_cluster_id
*
KThreadSliceSize
+
iK
();
// do element-wise pre-reduction operation
if
constexpr
(
HaveIndexInput
)
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
{
in_thread_val_buf
(
Number
<
offset
>
{}));
auto
threadwise_src_idx_load
=
ThreadwiseTensorSliceTransfer_v2
<
IndexDataType
,
IndexDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
do
{
// load the thread slice
threadwise_src_val_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
threadwise_src_idx_load
.
Run
(
in_grid_desc_m_k
,
in_global_idx_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_idx_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
tmpValue
=
identityVal
;
IndexDataType
tmpIndex
=
0
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
AccumulationWithIndex
::
Calculate
(
tmpValue
,
in_thread_val_buf
[
Number
<
offset
>
{}],
tmpIndex
,
in_thread_idx_buf
[
Number
<
offset
>
{}]);
});
BlockwiseReduceWithIndex
::
Reduce
(
reduce_work_val_buf
,
reduce_work_idx_buf
,
tmpValue
,
tmpIndex
);
AccumulationWithIndex
::
Calculate
(
accu_value_buf
(
iM
),
tmpValue
,
accu_index_buf
(
iM
),
tmpIndex
);
});
});
AccDataType
tmpValue
=
zeroVal
;
threadwise_src_val_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
IndexDataType
tmpIndex
=
0
;
threadwise_src_idx_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
AccumulationWithIndex
::
Calculate
(
tmpValue
,
reducedTiles
++
;
in_thread_val_buf
[
Number
<
offset
>
{}],
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
tmpIndex
,
}
in_thread_idx_buf
[
Number
<
offset
>
{}]);
else
{
index_t
indexOffset
=
0
;
do
{
// load the thread slice
threadwise_src_val_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf
(
Number
<
offset
>
{})
=
indexOffset
+
thread_k_cluster_id
*
KThreadSliceSize
+
iK
();
// do element-wise pre-reduction operation
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
in_thread_val_buf
(
Number
<
offset
>
{}));
});
AccDataType
tmpValue
=
identityVal
;
IndexDataType
tmpIndex
=
0
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
AccumulationWithIndex
::
Calculate
(
tmpValue
,
in_thread_val_buf
[
Number
<
offset
>
{}],
tmpIndex
,
in_thread_idx_buf
[
Number
<
offset
>
{}]);
});
BlockwiseReduceWithIndex
::
Reduce
(
reduce_work_val_buf
,
reduce_work_idx_buf
,
tmpValue
,
tmpIndex
);
AccumulationWithIndex
::
Calculate
(
accu_value_buf
(
iM
),
tmpValue
,
accu_index_buf
(
iM
),
tmpIndex
);
});
});
BlockwiseReduceWithIndex
::
Reduce
(
threadwise_src_val_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
reduce_work_val_buf
,
reduce_work_idx_buf
,
tmpValue
,
tmpIndex
);
AccumulationWithIndex
::
Calculate
(
accu_value_buf
(
iM
),
tmpValue
,
accu_index_buf
(
iM
),
tmpIndex
);
});
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
indexOffset
+=
K_BlockTileSize
;
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
};
indexOffset
+=
K_BlockTileSize
;
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{}
;
reducedTiles
++
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
if
(
thread_k_cluster_id
==
0
)
{
// for indiced operation, acc_elementwise_op shoud do nothing
acc_elementwise_op
(
accu_value_buf
(
I
),
accu_value_buf
(
I
));
constexpr
auto
reduced_data_desc
=
make_naive_tensor_descriptor_packed
(
accu_value_buf
(
I
)
*=
alpha
;
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
}
});
if
(
thread_k_cluster_id
==
0
)
if
(
thread_k_cluster_id
==
0
)
{
{
auto
threadwise_workspace_val_store
=
if
(
!
float_equal_zero
{}(
beta
))
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
OutDataType
,
OutGridDesc_M
,
decltype
(
reduced_data_desc
),
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
1
,
true
>
(
out_grid_desc_m
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
out_global_val_buf
,
reduced_data_desc
,
make_tuple
(
I0
),
priorDstValueBuf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValueBuf
[
I
])
*
beta
;
});
};
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
Acc
DataType
,
Out
DataType
,
decltype
(
reduced_data_desc
),
decltype
(
reduced_data_desc
),
Workspace
Desc_M
_K
,
OutGrid
Desc_M
,
PassThroughOp
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
,
1
>
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
>
,
1
,
0
,
1
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
true
>
(
true
>
(
workspace_desc_m_k
,
out_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
),
block_local_id
),
PassThroughOp
{});
PassThroughOp
{});
auto
threadwise_
workspace
_idx_store
=
auto
threadwise_
dst
_idx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
IndexDataType
,
ThreadwiseTensorSliceTransfer_v1r3
<
IndexDataType
,
IndexDataType
,
IndexDataType
,
decltype
(
reduced_data_desc
),
decltype
(
reduced_data_desc
),
Workspace
Desc_M
_K
,
OutGrid
Desc_M
,
PassThroughOp
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
,
1
>
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
>
,
1
,
0
,
1
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
true
>
(
true
>
(
workspace_desc_m_k
,
out_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
),
block_local_id
),
PassThroughOp
{});
PassThroughOp
{});
threadwise_
workspace
_val_store
.
Run
(
reduced_data_desc
,
threadwise_
dst
_val_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
),
accu_value_buf
,
accu_value_buf
,
workspace
_desc_m
_k
,
out_grid
_desc_m
,
workspace
_global_val_buf
);
out
_global_val_buf
);
threadwise_
workspace
_idx_store
.
Run
(
reduced_data_desc
,
threadwise_
dst
_idx_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
),
accu_index_buf
,
accu_index_buf
,
workspace
_desc_m
_k
,
out_grid
_desc_m
,
workspace
_global_idx_buf
);
out
_global_idx_buf
);
}
}
};
};
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp
deleted
100644 → 0
View file @
48918ab9
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseReduction
,
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InGridDesc_M_K
,
typename
OutGridDesc_M
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
>
__global__
void
kernel_reduce_multiblock_atocmi_add
(
const
InGridDesc_M_K
in_grid_desc_m_k
,
const
OutGridDesc_M
out_grid_desc_m
,
const
InElementwiseOperation
in_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
OutDataType
*
const
__restrict__
p_out_global
)
{
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
block_group_size
,
num_k_block_tile_iteration
,
alpha
,
p_in_global
,
p_out_global
);
};
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InGridDesc_M_K
,
typename
OutGridDesc_M
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
,
bool
PropagateNan
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
(
MThreadSliceSize
%
OutDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
InSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
BlockwiseReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ReduceOperation
,
PropagateNan
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
using
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
__device__
static
void
Run
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
OutDataType
*
const
__restrict__
p_out_global
)
{
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zeroVal
;
});
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
block_group_size
;
const
index_t
block_local_id
=
block_global_id
%
block_group_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
index_t
reducedTiles
=
0
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// do element-wise pre-reduction operation
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in_elementwise_op
(
in_thread_buf
(
Number
<
offset
>
{}),
in_thread_buf
(
Number
<
offset
>
{}));
});
});
ThreadwiseReduce
::
Reduce
(
in_thread_buf
,
accu_value_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
// Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
// reduced output to the global location corresponding to each invariant dimension to get a
// consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions.
static_for
<
0
,
MThreadSliceSize
,
1
>
{}(
[
&
](
auto
I
)
{
BlockwiseReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I
));
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
(
thread_k_cluster_id
==
0
)
{
acc_elementwise_op
(
accu_value_buf
(
I
),
accu_value_buf
(
I
));
accu_value_buf
(
I
)
*=
alpha
;
}
});
if
(
thread_k_cluster_id
==
0
)
{
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
OutDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
AtomicAdd
,
1
,
true
>
(
out_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_dst_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
out_global_buf
);
}
};
};
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp
View file @
a3b4c5cb
...
@@ -37,7 +37,8 @@
...
@@ -37,7 +37,8 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseReduction
,
template
<
typename
GridwiseReduction
,
bool
NeedIndices
,
bool
OutputIndex
,
bool
HaveIndexInput
,
typename
InDataType
,
typename
InDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
AccDataType
,
...
@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
...
@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const
InElementwiseOperation
in_elementwise_op
,
const
InElementwiseOperation
in_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
,
AccDataType
alpha
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
const
IndexDataType
*
const
__restrict__
p_in_index_global
,
AccDataType
beta
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
OutDataType
*
const
__restrict__
p_out_
value_
global
,
IndexDataType
*
const
__restrict__
p_
indices
_global
)
IndexDataType
*
const
__restrict__
p_
out_index
_global
)
{
{
if
constexpr
(
!
NeedIndices
)
if
constexpr
(
!
OutputIndex
)
{
{
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
out_grid_desc_m
,
out_grid_desc_m
,
in_elementwise_op
,
in_elementwise_op
,
acc_elementwise_op
,
acc_elementwise_op
,
alpha
,
alpha
,
p_in_global
,
p_in_
value_
global
,
beta
,
beta
,
p_out_global
,
p_out_value_global
);
p_indices_global
);
}
}
else
else
{
{
GridwiseReduction
::
RunWithIndices
(
in_grid_desc_m_k
,
GridwiseReduction
::
template
RunWithIndex
<
HaveIndexInput
>(
in_grid_desc_m_k
,
out_grid_desc_m
,
out_grid_desc_m
,
in_elementwise_op
,
in_elementwise_op
,
acc_elementwise_op
,
acc_elementwise_op
,
alpha
,
alpha
,
p_in_global
,
p_in_value_global
,
beta
,
p_in_index_global
,
p_out_global
,
beta
,
p_indices_global
);
p_out_value_global
,
p_out_index_global
);
};
};
};
};
...
@@ -91,11 +93,9 @@ template <typename InDataType,
...
@@ -91,11 +93,9 @@ template <typename InDataType,
typename
ReduceOperation
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
AccElementwiseOperation
,
InMemoryDataOperationEnum
OutMemoryDataOperation
,
bool
PropagateNan
,
bool
PropagateNan
,
bool
BetaIsZero
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorDim
,
...
@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
const
InElementwiseOperation
&
in_elementwise_op
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
AccDataType
alpha
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
const
InDataType
*
const
__restrict__
p_in_
value_
global
,
AccDataType
beta
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
OutDataType
*
const
__restrict__
p_out_value_global
)
IndexDataType
*
const
__restrict__
p_indices_global
)
{
{
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceSrcDesc_M_K
,
...
@@ -136,21 +135,21 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -136,21 +135,21 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation
,
ReduceOperation
,
PropagateNan
>
;
PropagateNan
>
;
(
void
)
p_indices_global
;
const
auto
identityVal
=
ReduceOperation
::
GetIdentityValue
()
;
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
const
auto
in_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
in_grid_desc_m_k
.
GetElementSpaceSize
(),
p_in_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zero
Val
));
type_convert
<
InDataType
>
(
identity
Val
));
auto
dst_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
dst_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
p_out_
value_
global
,
out_grid_desc_m
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_buf
;
in_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zero
Val
;
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
identity
Val
;
});
const
auto
toReduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
auto
toReduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
...
@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t
thread_global_1d_id
=
get_block_1d_id
()
*
BlockSize
+
get_thread_local_1d_id
();
index_t
thread_global_1d_id
=
get_block_1d_id
()
*
BlockSize
+
get_thread_local_1d_id
();
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
auto
threadwise_src_val_load
=
AccDataType
,
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
InGridDesc_M_K
,
AccDataType
,
decltype
(
thread_buffer_desc
),
InGridDesc_M_K
,
ThreadBufferLengths
,
decltype
(
thread_buffer_desc
),
ThreadBufferDimAccessOrder
,
ThreadBufferLengths
,
InSrcVectorDim
,
ThreadBufferDimAccessOrder
,
InSrcVectorSize
,
InSrcVectorDim
,
1
,
InSrcVectorSize
,
false
>
(
1
,
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
KThreadSliceSize
);
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
KThreadSliceSize
);
index_t
reducedLength
=
0
;
index_t
reducedLength
=
0
;
do
do
{
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
threadwise_src_
val_
load
.
Run
(
in_grid_desc_m_k
,
in_global_buf
,
in_global_
val_
buf
,
thread_buffer_desc
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
in_thread_buf
);
in_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// do element-wise pre-reduction operation
// do element-wise pre-reduction operation
...
@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
ThreadwiseReduce
::
Reduce
(
in_thread_buf
,
accu_value_buf
);
ThreadwiseReduce
::
Reduce
(
in_thread_buf
,
accu_value_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
threadwise_src_
val_
load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
reducedLength
+=
KThreadSliceSize
;
reducedLength
+=
KThreadSliceSize
;
}
while
(
reducedLength
<
toReduceLength
);
}
while
(
reducedLength
<
toReduceLength
);
...
@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
if
constexpr
(
!
BetaIsZero
)
if
(
!
float_equal_zero
{}(
beta
)
)
{
{
if
(
!
float_equal_zero
{}(
beta
))
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
{
OutDataType
,
auto
threadwise_dst_load
=
OutGridDesc_M
,
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
decltype
(
reduced_data_desc
),
OutDataType
,
Sequence
<
MThreadSliceSize
>
,
OutGridDesc_M
,
Sequence
<
0
>
,
decltype
(
reduced_data_desc
),
0
,
Sequence
<
MThreadSliceSize
>
,
1
,
Sequence
<
0
>
,
1
,
0
,
true
>
(
1
,
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
));
1
,
true
>
(
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
));
priorDstValue_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
priorDstValue_buf
;
dst_global_buf
,
reduced_data_desc
,
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
make_tuple
(
I0
),
dst_global_buf
,
priorDstValue_buf
);
reduced_data_desc
,
make_tuple
(
I0
),
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
priorDstValue_buf
);
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValue_buf
[
I
])
*
beta
;
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValue_buf
[
I
])
*
beta
;
});
};
};
};
auto
threadwise_dst_store
=
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
OutDataType
,
OutDataType
,
decltype
(
reduced_data_desc
),
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
OutGridDesc_M
,
PassThroughOp
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
Sequence
<
0
>
,
0
,
0
,
OutDstVectorSize
,
OutDstVectorSize
,
OutMemoryDataOperation
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
(
false
>
(
out_grid_desc_m
,
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
),
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
),
PassThroughOp
{});
PassThroughOp
{});
threadwise_dst_store
.
Run
(
threadwise_dst_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
dst_global_buf
);
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
dst_global_buf
);
};
};
__device__
static
void
RunWithIndices
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
template
<
bool
HaveIndexInput
>
const
OutGridDesc_M
&
out_grid_desc_m
,
__device__
static
void
RunWithIndex
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
const
InElementwiseOperation
&
in_elementwise_op
,
AccDataType
alpha
,
const
AccElementwiseOperation
&
acc_elementwise_op
,
const
InDataType
*
const
__restrict__
p_in_global
,
AccDataType
alpha
,
AccDataType
beta
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_in_index_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
,
IndexDataType
*
const
__restrict__
p_out_index_global
)
{
{
using
ThreadwiseReduceWithIndex
=
ThreadwiseReductionWithIndex
<
AccDataType
,
using
ThreadwiseReduceWithIndex
=
ThreadwiseReductionWithIndex
<
AccDataType
,
IndexDataType
,
IndexDataType
,
...
@@ -279,14 +276,19 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -279,14 +276,19 @@ struct GridwiseReduction_mk_to_m_threadwise
(
void
)
acc_elementwise_op
;
(
void
)
acc_elementwise_op
;
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
const
auto
identityVal
=
ReduceOperation
::
GetIdentityValue
();
const
auto
in_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
identityVal
));
const
auto
in_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_index_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
auto
out_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
out_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
p_out_
value_
global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
out_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
out_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
indices
_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
p_
out_index
_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_val_buf
;
in_thread_val_buf
;
...
@@ -301,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -301,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
,
true
>
accu_index_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
MThreadSliceSize
,
true
>
accu_index_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zero
Val
;
accu_value_buf
(
I
)
=
identity
Val
;
accu_index_buf
(
I
)
=
0
;
accu_index_buf
(
I
)
=
0
;
});
});
...
@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t
thread_global_1d_id
=
get_block_1d_id
()
*
BlockSize
+
get_thread_local_1d_id
();
index_t
thread_global_1d_id
=
get_block_1d_id
()
*
BlockSize
+
get_thread_local_1d_id
();
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
auto
threadwise_src_val_load
=
AccDataType
,
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
InGridDesc_M_K
,
AccDataType
,
decltype
(
thread_buffer_desc
),
InGridDesc_M_K
,
ThreadBufferLengths
,
decltype
(
thread_buffer_desc
),
ThreadBufferDimAccessOrder
,
ThreadBufferLengths
,
InSrcVectorDim
,
ThreadBufferDimAccessOrder
,
InSrcVectorSize
,
InSrcVectorDim
,
1
,
InSrcVectorSize
,
false
>
(
1
,
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
KThreadSliceSize
);
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
KThreadSliceSize
);
index_t
indexStart
=
0
;
index_t
indexStart
=
0
;
index_t
reducedLength
=
0
;
index_t
reducedLength
=
0
;
do
if
constexpr
(
HaveIndexInput
)
{
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
auto
threadwise_src_idx_load
=
in_global_buf
,
ThreadwiseTensorSliceTransfer_v2
<
IndexDataType
,
thread_buffer_desc
,
IndexDataType
,
make_tuple
(
I0
,
I0
),
InGridDesc_M_K
,
in_thread_val_buf
);
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
do
{
threadwise_src_val_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
threadwise_src_idx_load
.
Run
(
in_grid_desc_m_k
,
in_global_idx_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_idx_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// do element-wise pre-reduction operation
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
in_thread_val_buf
(
Number
<
offset
>
{}));
});
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
ThreadwiseReduceWithIndex
::
Reduce
(
// do element-wise pre-reduction operation
in_thread_val_buf
,
in_thread_idx_buf
,
accu_value_buf
,
accu_index_buf
);
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in_thread_idx_buf
(
Number
<
offset
>
{})
=
indexStart
+
iK
();
threadwise_src_val_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
threadwise_src_idx_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
indexStart
+=
KThreadSliceSize
;
in_thread_val_buf
(
Number
<
offset
>
{}));
reducedLength
+=
KThreadSliceSize
;
}
while
(
reducedLength
<
toReduceLength
);
}
else
{
do
{
threadwise_src_val_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// do element-wise pre-reduction operation
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in_thread_idx_buf
(
Number
<
offset
>
{})
=
indexStart
+
iK
();
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
in_thread_val_buf
(
Number
<
offset
>
{}));
});
});
});
});
ThreadwiseReduceWithIndex
::
Reduce
(
ThreadwiseReduceWithIndex
::
Reduce
(
in_thread_val_buf
,
in_thread_idx_buf
,
accu_value_buf
,
accu_index_buf
);
in_thread_val_buf
,
in_thread_idx_buf
,
accu_value_buf
,
accu_index_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
threadwise_src_
val_
load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
indexStart
+=
KThreadSliceSize
;
indexStart
+=
KThreadSliceSize
;
reducedLength
+=
KThreadSliceSize
;
reducedLength
+=
KThreadSliceSize
;
}
while
(
reducedLength
<
toReduceLength
);
}
while
(
reducedLength
<
toReduceLength
);
};
// for indiced operation, acc_elementwise_op shoud do nothing
// for indiced operation, acc_elementwise_op shoud do nothing
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
if
constexpr
(
!
BetaIsZero
)
if
(
!
float_equal_zero
{}(
beta
)
)
{
{
if
(
!
float_equal_zero
{}(
beta
))
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
{
OutDataType
,
auto
threadwise_dst_load
=
OutGridDesc_M
,
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
decltype
(
reduced_data_desc
),
OutDataType
,
Sequence
<
MThreadSliceSize
>
,
OutGridDesc_M
,
Sequence
<
0
>
,
decltype
(
reduced_data_desc
),
0
,
Sequence
<
MThreadSliceSize
>
,
1
,
Sequence
<
0
>
,
1
,
0
,
false
>
(
1
,
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
));
1
,
false
>
(
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
));
priorDstValue_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
priorDstValue_buf
;
out_global_val_buf
,
reduced_data_desc
,
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
make_tuple
(
I0
),
out_global_val_buf
,
priorDstValue_buf
);
reduced_data_desc
,
make_tuple
(
I0
),
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
priorDstValue_buf
);
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValue_buf
[
I
])
*
beta
;
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValue_buf
[
I
])
*
beta
;
});
};
};
};
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
...
@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence
<
0
>
,
Sequence
<
0
>
,
0
,
0
,
OutDstVectorSize
,
OutDstVectorSize
,
In
MemoryDataOperation
Enum
::
Set
,
Out
MemoryDataOperation
,
1
,
1
,
false
>
(
false
>
(
out_grid_desc_m
,
out_grid_desc_m
,
...
@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
...
@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence
<
0
>
,
Sequence
<
0
>
,
0
,
0
,
OutDstVectorSize
,
OutDstVectorSize
,
In
MemoryDataOperation
Enum
::
Set
,
Out
MemoryDataOperation
,
1
,
1
,
false
>
(
false
>
(
out_grid_desc_m
,
out_grid_desc_m
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp
0 → 100644
View file @
a3b4c5cb
#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename Gridwise5AryEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                           const BDataType* __restrict__ p_b_global,
                                           const CDataType* __restrict__ p_c_global,
                                           const DDataType* __restrict__ p_d_global,
                                           const EDataType* __restrict__ p_e_global,
                                           FDataType* __restrict__ p_f_global,
                                           const AGridDesc_M a_grid_desc_m,
                                           const BGridDesc_M b_grid_desc_m,
                                           const CGridDesc_M c_grid_desc_m,
                                           const DGridDesc_M d_grid_desc_m,
                                           const EGridDesc_M e_grid_desc_m,
                                           const FGridDesc_M f_grid_desc_m,
                                           const ElementwiseFunctor functor)
{
    Gridwise5AryEltwise::Run(p_a_global,
                             p_b_global,
                             p_c_global,
                             p_d_global,
                             p_e_global,
                             p_f_global,
                             a_grid_desc_m,
                             b_grid_desc_m,
                             c_grid_desc_m,
                             d_grid_desc_m,
                             e_grid_desc_m,
                             f_grid_desc_m,
                             functor);
}

// TODO - implement n-ary Elementwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename ComputeDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();

        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               const CDataType* __restrict__ p_c_global,
                               const DDataType* __restrict__ p_d_global,
                               const EDataType* __restrict__ p_e_global,
                               FDataType* __restrict__ p_f_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const DGridDesc_M d_grid_desc_m,
                               const EGridDesc_M e_grid_desc_m,
                               const FGridDesc_M f_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());
        const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_grid_desc_m.GetElementSpaceSize());
        const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_global, e_grid_desc_m.GetElementSpaceSize());
        auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_f_global, f_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             AGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             BGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             BScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_load =
            ThreadwiseTensorSliceTransfer_v2<CDataType,
                                             ComputeDataType,
                                             CGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             CScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{c_grid_desc_m, thread_store_global_offset};

        auto d_global_load =
            ThreadwiseTensorSliceTransfer_v2<DDataType,
                                             ComputeDataType,
                                             DGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             DScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{d_grid_desc_m, thread_store_global_offset};

        auto e_global_load =
            ThreadwiseTensorSliceTransfer_v2<EDataType,
                                             ComputeDataType,
                                             EGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             EScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{e_grid_desc_m, thread_store_global_offset};

        auto f_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               FDataType,
                                               decltype(thread_desc_m),
                                               FGridDesc_M,
                                               PassThrough,
                                               Sequence<MPerThread>, // SliceLengths
                                               Sequence<0>,          // DimAccessOrder
                                               0,                    // DstVectorDim
                                               FScalarPerVector,     // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                f_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = c_grid_desc_m.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = M / (loop_step);
        do
        {
            // read and process MPerThread elements
            a_global_load.Run(
                a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
            b_global_load.Run(
                b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
            c_global_load.Run(
                c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
            d_global_load.Run(
                d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
            e_global_load.Run(
                e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);

            static_for<0, MPerThread, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));

                functor(f_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}),
                        c_thread_buf(Number<offset>{}),
                        d_thread_buf(Number<offset>{}),
                        e_thread_buf(Number<offset>{}));
            });

            f_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               f_thread_buf,
                               f_grid_desc_m,
                               f_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
            c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
            d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
            e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
            f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
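Below is a minimal usage sketch, not part of this commit, of how the new 5-ary elementwise kernel might be instantiated and launched. The functor AddFive, the device pointers p_a … p_f, the length M, and the block/tile sizes are illustrative assumptions; the kernel advances its grid-stride loop by blockPerGrid * blockSize * MPerThread elements per iteration, so the grid below is sized assuming M is a multiple of block_size * MPerThread.

#include <hip/hip_runtime.h>

#include "gridwise_5ary_Elementwise_1d.hpp"

// Hypothetical functor for illustration: f = a + b + c + d + e (on ComputeDataType).
struct AddFive
{
    template <typename Y, typename X>
    __host__ __device__ void
    operator()(Y& f, const X& a, const X& b, const X& c, const X& d, const X& e) const
    {
        f = a + b + c + d + e;
    }
};

// Assumed device pointers of length M; types and sizes chosen only for the sketch.
void run_5ary_add(const ck::half_t* p_a, const ck::half_t* p_b, const ck::half_t* p_c,
                  const ck::half_t* p_d, const ck::half_t* p_e, ck::half_t* p_f,
                  ck::index_t M)
{
    constexpr ck::index_t MPerThread = 8;
    constexpr ck::index_t block_size = 256;

    // 1D packed descriptor of length M, shared by all six tensors in this sketch
    const auto desc_m = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M));
    using Desc        = decltype(desc_m);
    using F16         = ck::half_t;

    using GridwiseEltwise =
        ck::Gridwise5AryElementwise_1D<F16, F16, F16, F16, F16, F16,
                                       float, // ComputeDataType
                                       Desc, Desc, Desc, Desc, Desc, Desc,
                                       AddFive,
                                       MPerThread,
                                       8, 8, 8, 8, 8, 8>; // ScalarPerVector for A..F

    const auto kernel = ck::kernel_5ary_elementwise_1d<GridwiseEltwise,
                                                       F16, F16, F16, F16, F16, F16,
                                                       Desc, Desc, Desc, Desc, Desc, Desc,
                                                       AddFive>;

    // grid-stride loop: assumes M is a multiple of block_size * MPerThread
    const ck::index_t grid_size = M / (block_size * MPerThread);

    hipLaunchKernelGGL(kernel, dim3(grid_size), dim3(block_size), 0, nullptr,
                       p_a, p_b, p_c, p_d, p_e, p_f,
                       desc_m, desc_m, desc_m, desc_m, desc_m, desc_m, AddFive{});
}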