gaoqiong / composable_kernel / Commits / f26fb605

Commit f26fb605, authored Jun 07, 2022 by wangshaojie6

Merge branch 'develop' into bwd_weight_bf16_splitk

Parents: 32d06c66, 1677cf70

Changes: 69 files
Showing 20 changed files with 828 additions and 488 deletions (+828 / -488):
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp              +98  -110
include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp             +3   -3
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp       +103 -1
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp              +28  -11
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp       +8   -8
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp       +6   -6
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp           +251 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp   +6   -6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp              +1   -1
include/ck/utility/math_v2.hpp                                                  +60  -10
include/ck/utility/reduction_functions_accumulate.hpp                           +14  -21
include/ck/utility/reduction_operator.hpp                                       +7   -10
library/include/ck/library/host_tensor/host_reduce_util.hpp                     +0   -257
library/include/ck/library/host_tensor/host_reduction.hpp                       +33  -38
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp   +203 -0
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp  +1 -1
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp  +1 -1
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp  +1 -1
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp  +1 -1
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp                     +3 -2
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

@@ -24,14 +24,12 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop,
-          index_t MaxGroupCount>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdlops_v2r3(
-        const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
+    kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
         const index_t group_count,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,

@@ -42,39 +40,17 @@ __global__ void
     const index_t block_id = get_block_1d_id();

-#if 1
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
-           i < group_count)
-        {
-            auto group_id = i;
-            GridwiseGemm::template Run<HasMainKBlockLoop>(
-                gemm_descs[group_id].a_ptr,
-                gemm_descs[group_id].b_ptr,
-                gemm_descs[group_id].c_ptr,
-                p_shared,
-                gemm_descs[group_id].a_grid_desc_k0_m_k1_,
-                gemm_descs[group_id].b_grid_desc_k0_n_k1_,
-                gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                a_element_op,
-                b_element_op,
-                c_element_op,
-                gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
-        }
-    });
-#else
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
+    const auto gemm_desc_ptr =
+        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

     index_t group_id = 0;

-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
-                    i < group_count)
-                       ? i
-                       : group_id;
-    });
-
-    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
+    for(index_t i = 0; i < group_count; i++)
+    {
+        group_id = (block_id >= gemm_desc_ptr[i].BlockStart_ &&
+                    block_id < gemm_desc_ptr[i].BlockEnd_)
+                       ? i
+                       : group_id;
+    }

     GridwiseGemm::template Run<HasMainKBlockLoop>(
         gemm_desc_ptr[group_id].a_ptr,

@@ -87,11 +63,9 @@ __global__ void
         a_element_op,
         b_element_op,
         c_element_op,
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        block_id_grp);
-#endif
+        gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
 #else
-    ignore = gemm_descs;
+    ignore = gemm_descs_const;
     ignore = group_count;
     ignore = a_element_op;
     ignore = b_element_op;

@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
         {
             grid_size_ = 0;

+            gemm_descs_args_workspace_ = nullptr;
+
             group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());

             if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&

@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
         std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;

+        void* gemm_descs_args_workspace_;
+
         index_t grid_size_;
     };

@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
-
             bool has_main_k_block_loop = true;

-            static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-                if(i < arg.gemm_desc_kernel_arg_.size())
-                {
-                    gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];
-
-                    std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
-                              << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
-                              << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                              << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
-
-                    std::cout << ", arg.b_grid_desc_k0_n_k1_{"
-                              << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
-                              << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                              << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
-
-                    std::cout << ", arg.c_grid_desc_m_n_{ "
-                              << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
-                              << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
-                              << std::endl;
-
-                    if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
-                                                    gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
-                                                    gemm_desc_kernel_args[i].c_grid_desc_m_n_,
-                                                    gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
-                    {
-                        throw std::runtime_error(
-                            "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
-                    }
-
-                    const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
-                                   gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
-
-                    if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
-                    {
-                        throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
-                    }
-                }
-            });
+            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
+            {
+                std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
+                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
+
+                std::cout << ", arg.b_grid_desc_k0_n_k1_{"
+                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
+
+                std::cout << ", arg.c_grid_desc_m_n_{ "
+                          << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
+                          << std::endl;
+
+                if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
+                                                arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
+                                                arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
+                                                arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
+                {
+                    throw std::runtime_error(
+                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+                }
+
+                const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
+                               arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
+
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
+                {
+                    throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
+                }
+            }
+
+            hipGetErrorString(hipMemcpy(arg.gemm_descs_args_workspace_,
+                                        arg.gemm_desc_kernel_arg_.data(),
+                                        arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
+                                        hipMemcpyHostToDevice));

             float ave_time = 0;

@@ -523,19 +501,19 @@ struct DeviceGroupedGemmXdl
                 kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
                                                 ADataType, // TODO: distiguish A/B datatype
                                                 CDataType,
-                                                remove_reference_t<GemmDescKernelArg>,
+                                                GemmDescKernelArg,
                                                 AElementwiseOperation,
                                                 BElementwiseOperation,
                                                 CElementwiseOperation,
-                                                true,
-                                                MaxGroupCount>;
+                                                true>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
                                                   dim3(arg.grid_size_),
                                                   dim3(BlockSize),
                                                   0,
-                                                  gemm_desc_kernel_args,
+                                                  cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
                                                   arg.gemm_desc_kernel_arg_.size(),
                                                   arg.a_element_op_,
                                                   arg.b_element_op_,

@@ -547,19 +525,19 @@ struct DeviceGroupedGemmXdl
                 kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
                                                 ADataType, // TODO: distiguish A/B datatype
                                                 CDataType,
-                                                remove_reference_t<GemmDescKernelArg>,
+                                                GemmDescKernelArg,
                                                 AElementwiseOperation,
                                                 BElementwiseOperation,
                                                 CElementwiseOperation,
-                                                false,
-                                                MaxGroupCount>;
+                                                false>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
                                                   dim3(arg.grid_size_),
                                                   dim3(BlockSize),
                                                   0,
-                                                  gemm_desc_kernel_args,
+                                                  cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
                                                   arg.gemm_desc_kernel_arg_.size(),
                                                   arg.a_element_op_,
                                                   arg.b_element_op_,

@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
         return str.str();
     }

+    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
+    {
+        return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
+    }
+
+    void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
+    {
+        dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
+    }
 };
 } // namespace device
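With this change, the per-group GEMM descriptors no longer travel as a by-value kernel argument; the caller allocates a device workspace, hands it to the argument via SetWorkSpacePointer, and Run() copies the host-side GemmDescKernelArg array into it before launching. A minimal host-side sketch of that flow is below; GetWorkSpaceSize/SetWorkSpacePointer and StreamConfig come from the diff, while MakeInvoker and the hipMalloc/hipFree handling are assumptions for illustration only.

    #include <hip/hip_runtime.h>

    // Sketch only: DeviceOp stands for the DeviceGroupedGemmXdl specialization in the
    // diff, Argument for its Argument type; error handling is omitted for brevity.
    template <typename DeviceOp, typename Argument>
    float run_grouped_gemm_with_workspace(DeviceOp& gemm, Argument& arg)
    {
        // Size reported by the new override: group_count_ * sizeof(GemmDescKernelArg).
        void* workspace = nullptr;
        (void)hipMalloc(&workspace, gemm.GetWorkSpaceSize(&arg));

        // Run() later hipMemcpy's the host descriptors into this buffer and passes the
        // kernel a constant-address-space pointer to it.
        gemm.SetWorkSpacePointer(&arg, workspace);

        auto invoker   = gemm.MakeInvoker(); // assumption: the usual CK invoker pattern
        float ave_time = invoker.Run(arg, StreamConfig{});

        (void)hipFree(workspace);
        return ave_time;
    }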
include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp

@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
             if constexpr(use_multiblock)
             {
-                const auto zeroVal =
-                    ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
+                const auto identityVal =
+                    ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
                         OutMemoryDataOperation);

                 const auto kernel_pre =

@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                                                   0,
                                                   out_grid_desc_m_2,
                                                   arg.out_dev_,
-                                                  zeroVal);
+                                                  identityVal);
                 };

                 avg_time += launch_and_time_kernel(stream_config,
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2022 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
 #pragma once

 #include "data_type.hpp"

@@ -5,14 +30,22 @@ namespace ck {
 namespace tensor_operation {
 namespace binary_element_wise {

-struct Add
+template <typename Y, typename X1, typename X2>
+struct Add;
+
+template <>
+struct Add<double, double, double>
 {
     __host__ __device__ constexpr void
     operator()(double& dst, const double& src1, const double& src2) const
     {
         dst = src1 + src2;
     }
+};

+template <>
+struct Add<float, float, float>
+{
     __host__ __device__ constexpr void
     operator()(float& dst, const float& src1, const float& src2) const
     {

@@ -20,6 +53,75 @@ struct Add
     }
 };

+template <>
+struct Add<half_t, half_t, half_t>
+{
+    __host__ __device__ constexpr void
+    operator()(half_t& dst, const half_t& src1, const half_t& src2) const
+    {
+        dst = src1 + src2;
+    }
+};
+
+template <>
+struct Add<bhalf_t, bhalf_t, bhalf_t>
+{
+    __host__ __device__ constexpr void
+    operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
+    {
+        const float x1 = ck::type_convert<float>(src1);
+        const float x2 = ck::type_convert<float>(src2);
+        const float y  = x1 + x2;
+        dst            = ck::type_convert<bhalf_t>(y);
+    }
+};
+
+template <typename Y, typename X1, typename X2>
+struct Substract;
+
+template <>
+struct Substract<double, double, double>
+{
+    __host__ __device__ constexpr void
+    operator()(double& dst, const double& src1, const double& src2) const
+    {
+        dst = src1 - src2;
+    }
+};
+
+template <>
+struct Substract<float, float, float>
+{
+    __host__ __device__ constexpr void
+    operator()(float& dst, const float& src1, const float& src2) const
+    {
+        dst = src1 - src2;
+    }
+};
+
+template <>
+struct Substract<half_t, half_t, half_t>
+{
+    __host__ __device__ constexpr void
+    operator()(half_t& dst, const half_t& src1, const half_t& src2) const
+    {
+        dst = src1 - src2;
+    }
+};
+
+template <>
+struct Substract<bhalf_t, bhalf_t, bhalf_t>
+{
+    __host__ __device__ constexpr void
+    operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
+    {
+        const float x1 = ck::type_convert<float>(src1);
+        const float x2 = ck::type_convert<float>(src2);
+        const float y  = x1 - x2;
+        dst            = ck::type_convert<bhalf_t>(y);
+    }
+};
+
 } // namespace binary_element_wise
 } // namespace tensor_operation
 } // namespace ck
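The binary operator is now dispatched on the (Y, X1, X2) type triple, and the new bhalf_t specializations widen to float, compute, and narrow back, since bf16 has no native arithmetic here. A minimal usage sketch (an assumption: compiled inside CK so that this header, half_t/bhalf_t and ck::type_convert are available):

    void add_bf16_example()
    {
        using ck::tensor_operation::binary_element_wise::Add;

        const ck::bhalf_t a = ck::type_convert<ck::bhalf_t>(1.5f);
        const ck::bhalf_t b = ck::type_convert<ck::bhalf_t>(2.25f);
        ck::bhalf_t c;

        // The functor is stateless; internally it computes in float and narrows the result,
        // the same pattern the Substract<bhalf_t, ...> specialization uses.
        Add<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>{}(c, a, b);
    }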
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

 #pragma once

 #include "data_type.hpp"
+#include "math_v2.hpp"

 namespace ck {
 namespace tensor_operation {

@@ -143,6 +144,24 @@ struct AddHardswishAdd
     }
 };

+struct Normalize
+{
+    Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
+
+    __host__ __device__ constexpr void operator()(float& y,
+                                                  const float& x,
+                                                  const float& mean,
+                                                  const float& mean_square,
+                                                  const float& gamma,
+                                                  const float& beta) const
+    {
+        float variance = mean_square - (mean * mean);
+        y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
+    }
+
+    float epsilon_;
+};
+
 // Unary operators are usually called element-wisely before/after the reduction is executed on the
 // elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2

@@ -278,7 +297,7 @@ struct UnaryAbs<float, float>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -286,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
+    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -294,7 +313,7 @@ struct UnaryAbs<double, double>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -302,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const
-    {
-        int8_t sgn = x >> (8 - 1);
-
-        y = (x ^ sgn) - sgn;
-    };
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
 };

 template <typename Y, typename X>
@@ -318,7 +332,7 @@ struct UnarySqrt<float, float>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
 };

 template <>
@@ -326,7 +340,10 @@ struct UnarySqrt<double, double>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::sqrt(x); };
 };

 template <typename Y, typename X>
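The new Normalize functor recovers the variance from running first and second moments, var = E[x^2] - E[x]^2, and then applies the usual affine normalization. A quick host-side illustration of that arithmetic (assuming the element_wise namespace used elsewhere in this header; the numbers are made up for the example):

    #include <cassert>
    #include <cmath>

    void normalize_example()
    {
        ck::tensor_operation::element_wise::Normalize norm{1e-4f};

        const float x = 2.0f, mean = 1.0f, mean_square = 2.5f; // variance = 2.5 - 1*1 = 1.5
        const float gamma = 1.0f, beta = 0.0f;

        float y = 0.0f;
        norm(y, x, mean, mean_square, gamma, beta);

        // Same formula written out: (x - mean) / sqrt(var + eps) * gamma + beta.
        const float expected = (x - mean) / std::sqrt(1.5f + 1e-4f);
        assert(std::fabs(y - expected) < 1e-5f);
    }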
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp

@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                AccDataType beta,
                                OutDataType* const __restrict__ p_out_value_global)
     {
-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];

@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();

@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
         __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
         __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                                          thread_k_cluster_id * KThreadSliceSize));

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
             accu_index_buf(I) = 0;
         });

@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                                  in_thread_idx_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                AccDataType tmpValue = zeroVal;
+                AccDataType tmpValue = identityVal;
                 IndexDataType tmpIndex = 0;

                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {

@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                            in_thread_val_buf(Number<offset>{}));
                 });

-                AccDataType tmpValue = zeroVal;
+                AccDataType tmpValue = identityVal;
                 IndexDataType tmpIndex = 0;

                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp

@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
                                  ReduceOperation,
                                  PropagateNan>;

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

         const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
         (void)acc_elementwise_op;

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
         StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
             accu_index_buf(I) = 0;
         });
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp (new file, 0 → 100644)

#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename Gridwise5AryEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                           const BDataType* __restrict__ p_b_global,
                                           const CDataType* __restrict__ p_c_global,
                                           const DDataType* __restrict__ p_d_global,
                                           const EDataType* __restrict__ p_e_global,
                                           FDataType* __restrict__ p_f_global,
                                           const AGridDesc_M a_grid_desc_m,
                                           const BGridDesc_M b_grid_desc_m,
                                           const CGridDesc_M c_grid_desc_m,
                                           const DGridDesc_M d_grid_desc_m,
                                           const EGridDesc_M e_grid_desc_m,
                                           const FGridDesc_M f_grid_desc_m,
                                           const ElementwiseFunctor functor)
{
    Gridwise5AryEltwise::Run(p_a_global,
                             p_b_global,
                             p_c_global,
                             p_d_global,
                             p_e_global,
                             p_f_global,
                             a_grid_desc_m,
                             b_grid_desc_m,
                             c_grid_desc_m,
                             d_grid_desc_m,
                             e_grid_desc_m,
                             f_grid_desc_m,
                             functor);
}

// TODO - implement n-ary Elemenetwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename ComputeDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               const CDataType* __restrict__ p_c_global,
                               const DDataType* __restrict__ p_d_global,
                               const EDataType* __restrict__ p_e_global,
                               FDataType* __restrict__ p_f_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const DGridDesc_M d_grid_desc_m,
                               const EGridDesc_M e_grid_desc_m,
                               const FGridDesc_M f_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());
        const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_grid_desc_m.GetElementSpaceSize());
        const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_global, e_grid_desc_m.GetElementSpaceSize());
        auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_f_global, f_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             AGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             BGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             BScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_load =
            ThreadwiseTensorSliceTransfer_v2<CDataType,
                                             ComputeDataType,
                                             CGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             CScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{c_grid_desc_m, thread_store_global_offset};

        auto d_global_load =
            ThreadwiseTensorSliceTransfer_v2<DDataType,
                                             ComputeDataType,
                                             DGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             DScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{d_grid_desc_m, thread_store_global_offset};

        auto e_global_load =
            ThreadwiseTensorSliceTransfer_v2<EDataType,
                                             ComputeDataType,
                                             EGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             EScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{e_grid_desc_m, thread_store_global_offset};

        auto f_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               FDataType,
                                               decltype(thread_desc_m),
                                               FGridDesc_M,
                                               PassThrough,
                                               Sequence<MPerThread>, // SliceLengths
                                               Sequence<0>,          // DimAccessOrder
                                               0,                    // DstVectorDim
                                               FScalarPerVector,     // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,                    // DstScalarStrideInVector
                                               false>{f_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = c_grid_desc_m.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = M / (loop_step);
        do
        {
            // read and process MPerThread elements
            a_global_load.Run(a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
            b_global_load.Run(b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
            c_global_load.Run(c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
            d_global_load.Run(d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
            e_global_load.Run(e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);

            static_for<0, MPerThread, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
                functor(f_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}),
                        c_thread_buf(Number<offset>{}),
                        d_thread_buf(Number<offset>{}),
                        e_thread_buf(Number<offset>{}));
            });

            f_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               f_thread_buf,
                               f_grid_desc_m,
                               f_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
            c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
            d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
            e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
            f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
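Each thread of this kernel loads MPerThread elements from five 1-D inputs, applies one functor per element (output first, then the five inputs), writes one output, and strides the whole grid forward by gridSize * blockSize * MPerThread until the length is consumed. The element-wise functor with exactly that 5-in/1-out shape added in this commit is Normalize, so a plausible pairing is per-element normalization; the host-side sketch below only illustrates that data flow and is not CK API usage.

    #include <cmath>
    #include <vector>

    // What each element of the output becomes when the functor is Normalize:
    // y[i] = (x[i] - mean[i]) / sqrt(var[i] + eps) * gamma[i] + beta[i],
    // with var[i] recovered as mean_square[i] - mean[i]^2.
    std::vector<float> five_ary_normalize(const std::vector<float>& x,
                                          const std::vector<float>& mean,
                                          const std::vector<float>& mean_square,
                                          const std::vector<float>& gamma,
                                          const std::vector<float>& beta,
                                          float epsilon = 1e-4f)
    {
        std::vector<float> y(x.size());
        for(std::size_t i = 0; i < x.size(); ++i)
        {
            const float variance = mean_square[i] - mean[i] * mean[i];
            y[i] = (x[i] - mean[i]) / std::sqrt(variance + epsilon) * gamma[i] + beta[i];
        }
        return y;
    }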
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp

@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename DxsInElementwiseOperation,
-          typename DxsOutElementwiseOperation,
+          typename DxsAccElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -41,7 +41,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
             const DxsInElementwiseOperation dxs_in_element_op,
-            const DxsOutElementwiseOperation dxs_out_element_op,
+            const DxsAccElementwiseOperation dxs_out_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock

@@ -96,7 +96,7 @@ template <typename FloatAB,
           typename CElementwiseOperation,
           typename DxsReduceOperation,
           typename DxsInElementwiseOperation,
-          typename DxsOutElementwiseOperation,
+          typename DxsAccElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename DGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,

@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
                                const DxsInElementwiseOperation& dxs_in_element_op,
-                               const DxsOutElementwiseOperation& dxs_out_element_op,
+                               const DxsAccElementwiseOperation& dxs_out_element_op,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&

@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                                    false>;

             // Global write Gemm shuffle + reduction
-            const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
+            const auto d_identityVal = DReduceOperation::GetIdentityValue();

             static_for<0, mreduce_per_thread, 1>{}(
-                [&](auto I) { d_thread_buf(I) = d_zeroVal; });
+                [&](auto I) { d_thread_buf(I) = d_identityVal; });

             // reduce in VGPR
             static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
include/ck/utility/math_v2.hpp

@@ -3,11 +3,13 @@
 #include <cmath>

 #include "data_type.hpp"
-#include "half.hpp"
+#include "type.hpp"

 namespace ck {
 namespace math {

+// math functions for the host, some are implemented by calling C++ std functions
+
 static inline __host__ float abs(float x) { return std::abs(x); };

 static inline __host__ double abs(double x) { return std::abs(x); };

@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
 static inline __host__ half_t abs(half_t x)
 {
-    half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
+    uint16_t xx = ck::bit_cast<uint16_t>(x);

-    half_float::half abs_xx = half_float::abs(xx);
+    uint16_t abs_xx = xx & 0x7fff;

-    half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx);
+    half_t abs_x = ck::bit_cast<half_t>(abs_xx);

     return abs_x;
 };

-static inline __host__ float isnan(float x) { return std::isnan(x); };
+static inline __host__ bool isnan(float x) { return std::isnan(x); };

-static inline __host__ double isnan(double x) { return std::isnan(x); };
+static inline __host__ bool isnan(double x) { return std::isnan(x); };

-static inline __host__ int8_t isnan(int8_t x)
+static inline __host__ bool isnan(int8_t x)
 {
     (void)x;

     return false;
 };

-static inline __host__ int32_t isnan(int32_t x)
+static inline __host__ bool isnan(int32_t x)
 {
     (void)x;

     return false;

@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
 static inline __host__ bool isnan(half_t x)
 {
-    half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
+    uint16_t xx = ck::bit_cast<uint16_t>(x);

-    return half_float::isnan(xx);
+    return (xx & 0x7FFF) > 0x7C00;
 };

+static inline __host__ float sqrt(float x) { return std::sqrt(x); };
+
+static inline __host__ double sqrt(double x) { return std::sqrt(x); };
+
+// math functions for the HIP kernel, some are implemented by calling hip builtin functions
+
+static inline __device__ float abs(float x) { return ::abs(x); };
+
+static inline __device__ double abs(double x) { return ::abs(x); };
+
+static inline __device__ int8_t abs(int8_t x)
+{
+    int8_t sgn = x >> (8 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ int32_t abs(int32_t x)
+{
+    int32_t sgn = x >> (32 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
+
+static inline __device__ bool isnan(float x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(double x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(int8_t x)
+{
+    (void)x;
+
+    return false;
+};
+
+static inline __device__ bool isnan(int32_t x)
+{
+    (void)x;
+
+    return false;
+};
+
+static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
+
+static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
+
+static inline __device__ double sqrt(double x) { return ::sqrt(x); };
+
 } // namespace math
 } // namespace ck
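Two bit-level idioms carry this header: integer abs via sign extension (sgn is 0 for non-negative values and all ones for negative ones, so (x ^ sgn) - sgn flips and increments only when needed), and fp16 classification by masking off the sign bit and comparing against the all-ones exponent 0x7C00. A standalone host-only illustration, not part of the diff and using no CK types:

    #include <cassert>
    #include <cstdint>

    int8_t branchless_abs_i8(int8_t x)
    {
        const int8_t sgn = x >> 7;  // 0 for non-negative, -1 (all ones) for negative
        return (x ^ sgn) - sgn;     // flips the bits and adds 1 only when negative
    }

    bool fp16_isnan(uint16_t bits)
    {
        // Clear the sign bit; NaN means exponent all ones (0x7C00) and mantissa != 0.
        return (bits & 0x7FFF) > 0x7C00;
    }

    void bit_trick_examples()
    {
        assert(branchless_abs_i8(-7) == 7);
        assert(fp16_isnan(0x7E00));  // a quiet-NaN encoding
        assert(!fp16_isnan(0x7C00)); // +infinity is not NaN
    }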
include/ck/utility/reduction_functions_accumulate.hpp

@@ -27,6 +27,7 @@
 #define CK_REDUCTION_FUNCTIONS_BINOP_HPP

 #include "data_type.hpp"
+#include "math_v2.hpp"
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"

@@ -34,18 +35,6 @@
 namespace ck {
 namespace detail {

-template <typename T>
-static inline __device__ bool is_nan(T x)
-{
-    return (isnan(x));
-};
-
-template <>
-inline __device__ bool is_nan<half_t>(half_t x)
-{
-    return (__hisnan(x));
-};
-
 template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck;

@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 {
     // cppcheck-suppress constParameter
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
         ReduceOperation{}(accuVal, currVal);
     };

@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
 {
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
         }

@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
 template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
 struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
 {
-    __device__ static inline void
+    __host__ __device__ static inline void
     // cppcheck-suppress constParameter
     Calculate(AccDataType& accuVal,
               AccDataType currVal,

@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
 struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
 {
     // The method is called when the ReduceOperation is indexable and the user asked for indices
-    __device__ static inline void Calculate(AccDataType& accuVal,
-                                            AccDataType currVal,
-                                            IndexDataType& accuIndex,
-                                            IndexDataType currIndex)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal,
+                                                     AccDataType currVal,
+                                                     IndexDataType& accuIndex,
+                                                     IndexDataType currIndex)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
             accuIndex = currIndex;
include/ck/utility/reduction_operator.hpp
@@ -36,7 +36,7 @@ namespace reduce {
 // Every binary operator used in reduction is represented by a templated functor class. Each functor
 // class must provide at least
 // three members:
-// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
+// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
 // operator, "identity element" is the unique
 // element in the algebraic space that doesn't affect the value of other elements
 // when operated against them, and the concept is similar to zero vector in
...
@@ -59,7 +59,7 @@ struct Add
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...
@@ -76,7 +76,7 @@ struct Mul
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...
@@ -92,7 +92,7 @@ struct Max
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
+    __host__ __device__ static constexpr T GetIdentityValue()
     {
         return NumericLimits<T>::Lowest();
     };
...
@@ -125,10 +125,7 @@ struct Min
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
-    {
-        return NumericLimits<T>::Max();
-    };
+    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...
@@ -158,7 +155,7 @@ struct AMax
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...
@@ -184,7 +181,7 @@ struct AMax
 };

 template <typename T>
-T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
+T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
 {
     T result = ck::type_convert<T>(0.0f);
...
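The change in this header is a pure rename of the functor interface: GetReductionZeroVal() becomes GetIdentityValue(), matching the "identity element" wording in the comment (0 for Add, 1 for Mul, the lowest representable value for Max, NumericLimits<T>::Max() for Min). As a rough illustration of the contract such a reduce functor satisfies, here is a minimal Add-like sketch; GetIdentityValue() and IsCompatibleInMemoryDataOperation() follow the diff above, while the operator() accumulation form and the exact set of allowed memory operations are assumptions for illustration only.

// Sketch of the reduce-functor contract; assumes CK's InMemoryDataOperationEnum and the
// HIP __host__/__device__ qualifiers are in scope. Not the CK source.
template <typename T>
struct AddLike
{
    using dataType = T;

    // Identity element: combining it with any value leaves that value unchanged.
    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); }

    // A sum may be finalized with plain stores or atomic adds (assumed policy).
    __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set ||
               operation == InMemoryDataOperationEnum::AtomicAdd;
    }

    // Accumulate one value into the running result (assumed form).
    __host__ __device__ constexpr void operator()(T& a, T b) const { a = a + b; }
};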
library/include/ck/library/host_tensor/host_reduce_util.hpp
deleted 100644 → 0
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <functional>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {

using ck::NanPropagation;
using ck::ReduceTensorOp;

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
    using ck::math::abs;

    if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
    {
        return ([&](AccDataType& a_) { a_ = abs(a_); });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_) { a_ = a_ * a_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_) { a_ = abs(a_); });
    }
    else
    {
        // ReduceTensorOp::AVG:
        // ReduceTensorOp::ADD:
        // ReduceTensorOp::MUL:
        // ReduceTensorOp::MIN:
        // ReduceTensorOp::MAX:
        return ([&](AccDataType&) {});
    };
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
    using std::sqrt;

    if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_) { a_ = sqrt(a_); });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
    {
        return ([&, divider](AccDataType& a_) {
            a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
        });
    }
    else
    {
        // ReduceTensorOp::ADD:
        // ReduceTensorOp::NORM1:
        // ReduceTensorOp::MUL:
        // ReduceTensorOp::MIN:
        // ReduceTensorOp::MAX:
        // ReduceTensorOp::AMAX:
        return ([&](AccDataType&) {});
    }
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
    if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
                 ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
    {
        return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return ([&](AccDataType& a_, AccDataType b_) {
            if(a_ > b_)
                a_ = b_;
        });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_, AccDataType b_) {
            if(a_ < b_)
                a_ = b_;
        });
    }
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
    if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
            if(a_ > b_)
            {
                a_      = b_;
                changed = true;
            }
            else
                changed = false;
        });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
            if(a_ < b_)
            {
                a_      = b_;
                changed = true;
            }
            else
                changed = false;
        });
    }
    else
    {
        // ReduceTensorOp::ADD:
        // ReduceTensorOp::MUL:
        // ReduceTensorOp::AVG:
        // ReduceTensorOp::NORM1:
        // ReduceTensorOp::NORM2:
        return (std::function<void(AccDataType&, AccDataType, bool&)>{});
    };
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
    if constexpr(ReduceOpId == ReduceTensorOp::MUL)
    {
        return (static_cast<AccDataType>(1.0f));
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return (ck::NumericLimits<AccDataType>::Max());
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
    {
        return (ck::NumericLimits<AccDataType>::Lowest());
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
    {
        return (static_cast<AccDataType>(0.0f));
    }
    else
    {
        // ReduceTensorOp::ADD
        // ReduceTensorOp::AVG
        // ReduceTensorOp::NORM1
        // ReduceTensorOp::NORM2
        return (static_cast<AccDataType>(0.0f));
    };
};

template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
                     AccDataType& accuVal,
                     AccDataType currVal)
{
    using ck::math::isnan;

    if constexpr(!PropagateNan)
    {
        opReduce(accuVal, currVal);
    }
    else
    {
        if(isnan(currVal))
            accuVal = currVal;
        else
            opReduce(accuVal, currVal);
    };
};

template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
                               AccDataType& accuVal,
                               AccDataType currVal,
                               IndexDataType& accuIndex,
                               IndexDataType currIndex)
{
    using ck::math::isnan;

    if constexpr(!PropagateNan)
    {
        bool changed;

        opReduce(accuVal, currVal, changed);

        if(changed)
            accuIndex = currIndex;
    }
    else
    {
        if(isnan(currVal))
        {
            accuVal   = currVal;
            accuIndex = currIndex;
        }
        else
        {
            bool changed;

            opReduce(accuVal, currVal, changed);

            if(changed)
                accuIndex = currIndex;
        };
    };
};

}; // namespace host_reduce
}; // namespace ck

#endif
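The deleted header only defines these helpers; the way they were composed by the host reference reduction (still visible in the old lines of host_reduction.hpp below) was: apply PreUnaryOpFn to each loaded value, fold it into the accumulator with binop_with_nan_check (or its index-tracking variant), and apply PosUnaryOpFn once at the end. A compact sketch of that old flow for a NORM2 reduction, for orientation only:

// Old-style host NORM2 using the helpers removed above (sketch, not library code).
#include <cstdint>
#include <vector>

float host_norm2_old_style(const std::vector<float>& in)
{
    constexpr auto OpId = ck::ReduceTensorOp::NORM2;
    const auto divider  = static_cast<int32_t>(in.size());

    auto preUnaryOp = ck::host_reduce::PreUnaryOpFn<float, OpId>(divider); // x -> x * x
    auto posUnaryOp = ck::host_reduce::PosUnaryOpFn<float, OpId>(divider); // x -> sqrt(x)
    auto opReduce   = ck::host_reduce::ReduceOpFn<float, OpId>();          // (a, b) -> a = a + b

    float accuVal = ck::host_reduce::ReduceOpZeroVal<float, OpId>();       // 0.0f for NORM2

    for(float v : in)
    {
        preUnaryOp(v);
        ck::host_reduce::binop_with_nan_check<float, false>(opReduce, accuVal, v);
    }

    posUnaryOp(accuVal);
    return accuVal;
}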
library/include/ck/library/host_tensor/host_reduction.hpp
@@ -33,10 +33,10 @@
 #include "reduction_enums.hpp"
 #include "reduction_common.hpp"
-#include "host_reduce_util.hpp"
 #include "host_common_util.hpp"
 #include "host_tensor.hpp"
 #include "data_type.hpp"
+#include "reduction_functions_accumulate.hpp"

 template <int NDim>
 static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
...
@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          ck::ReduceTensorOp ReduceOpId,
+          typename ReduceOperation,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation,
           int Rank,
           int NumReduceDim,
           bool PropagateNan,
-          bool NeedIndices>
+          bool OutputIndex>
 struct ReductionHost
 {
     using IndexDataType = int32_t;
...
@@ -122,8 +124,6 @@ struct ReductionHost
     std::vector<int> reduceDims;
     IndexDataType divider;

-    std::function<void(AccDataType&)> preUnaryOp;
-    std::function<void(AccDataType&)> posUnaryOp;
     std::array<size_t, NumReduceDim> reduceLengths;
     std::array<size_t, NumReduceDim> reduceStrides;
     std::array<size_t, NumInvariantDim> invariantLengths;
...
@@ -137,9 +137,6 @@ struct ReductionHost
                   const std::vector<int>& invariantDims_,
                   const std::vector<int>& reduceDims_)
     {
-        using ck::host_reduce::PosUnaryOpFn;
-        using ck::host_reduce::PreUnaryOpFn;
-
         // this->outLengths = to_int_vector(outDesc.GetLengths());
         this->outStrides = outDesc.GetStrides();
...
@@ -171,9 +168,6 @@ struct ReductionHost
             invariant_dim_indexes.clear();
             get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
         };
-
-        preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
-        posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
     };

     void Run(float alpha,
...
@@ -182,7 +176,7 @@ struct ReductionHost
              OutDataType* out_data,
              IndexDataType* out_indices)
     {
-        if constexpr(NeedIndices)
+        if constexpr(OutputIndex)
         {
             RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
         }
...
@@ -201,15 +195,17 @@ struct ReductionHost
         using ck::float_equal_one;
         using ck::float_equal_zero;
         using ck::type_convert;
-        using ck::host_reduce::binop_with_index_and_nan_check;
-        using ck::host_reduce::ReduceOpFn2;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();
+        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
+                                                                        ReduceOperation,
+                                                                        AccDataType,
+                                                                        IndexDataType>;
+
+        InElementwiseOperation in_elementwise_op(divider);
+        AccElementwiseOperation acc_elementwise_op(divider);

         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::GetIdentityValue();
             IndexDataType accuIndex = 0;

             for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
...
@@ -219,15 +215,14 @@ struct ReductionHost
                 auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

                 auto currIndex = static_cast<IndexDataType>(i);

-                binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                    opReduce2, accuVal, currVal, accuIndex, currIndex);
+                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
             };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             if(!float_equal_one{}(alpha))
                 accuVal *= type_convert<AccDataType>(alpha);
...
@@ -241,7 +236,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::GetIdentityValue();
                 IndexDataType accuIndex = 0;

                 auto offset_invariant =
...
@@ -255,15 +250,14 @@ struct ReductionHost
                     auto currVal =
                         type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                    preUnaryOp(currVal);
+                    in_elementwise_op(currVal, currVal);

                     auto currIndex = static_cast<IndexDataType>(i);

-                    binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                        opReduce2, accuVal, currVal, accuIndex, currIndex);
+                    Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                 };

-                posUnaryOp(accuVal);
+                acc_elementwise_op(accuVal, accuVal);

                 if(!float_equal_one{}(alpha))
                     accuVal *= type_convert<AccDataType>(alpha);
...
@@ -308,15 +302,16 @@ struct ReductionHost
         using ck::float_equal_one;
         using ck::float_equal_zero;
         using ck::type_convert;
-        using ck::host_reduce::binop_with_nan_check;
-        using ck::host_reduce::ReduceOpFn;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
+        using Accumulation =
+            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+
+        InElementwiseOperation in_elementwise_op(divider);
+        AccElementwiseOperation acc_elementwise_op(divider);

         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::GetIdentityValue();

             for(const auto& reduce_index : reduce_dim_indexes)
             {
...
@@ -325,12 +320,12 @@ struct ReductionHost
                 auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

-                binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                Accumulation::Calculate(accuVal, currVal);
             };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             if(!float_equal_one{}(alpha))
                 accuVal *= type_convert<AccDataType>(alpha);
...
@@ -343,7 +338,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::GetIdentityValue();

                 auto offset_invariant =
                     get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
...
@@ -356,12 +351,12 @@ struct ReductionHost
                     auto currVal =
                         type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                    preUnaryOp(currVal);
+                    in_elementwise_op(currVal, currVal);

-                    binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                    Accumulation::Calculate(accuVal, currVal);
                 };

-                posUnaryOp(accuVal);
+                acc_elementwise_op(accuVal, accuVal);

                 if(!float_equal_one{}(alpha))
                     accuVal *= type_convert<AccDataType>(alpha);
...
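Net effect of this file's changes: the host reference reduction no longer dispatches on a ReduceTensorOp enum through std::function helpers; it is parameterized on the same ReduceOperation, InElementwiseOperation and AccElementwiseOperation types the device kernels use, takes its identity from ReduceOperation::GetIdentityValue(), and folds values through ck::detail::AccumulateWithNanCheck. A condensed sketch of the new inner loop with the multi-dimensional indexing stripped away; the types, the divider-taking constructors and the Calculate call follow the diff above, the rest is illustrative:

// Condensed illustration of the updated host reduction flow over a flat buffer.
#include <cstddef>

template <typename InDataType,
          typename AccDataType,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan>
AccDataType host_reduce_1d(const InDataType* in_data, std::size_t len, int divider)
{
    using Accumulation =
        ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    InElementwiseOperation in_elementwise_op(divider);   // per-element pre-op (e.g. square)
    AccElementwiseOperation acc_elementwise_op(divider); // final post-op (e.g. divide or sqrt)

    AccDataType accuVal = ReduceOperation::GetIdentityValue();

    for(std::size_t i = 0; i < len; ++i)
    {
        auto currVal = ck::type_convert<AccDataType>(in_data[i]);
        in_elementwise_op(currVal, currVal);       // replaces the old preUnaryOp(currVal)
        Accumulation::Calculate(accuVal, currVal); // NaN-aware fold with ReduceOperation
    }

    acc_elementwise_op(accuVal, accuVal);          // replaces the old posUnaryOp(accuVal)
    return accuVal;
}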
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
0 → 100644
This diff is collapsed.
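The new reference implementation itself is collapsed in this view, so its contents are not shown here. Purely for orientation (a generic illustration, not the code in reference_cgemm.hpp, and with a hypothetical helper name), a reference complex GEMM on split real/imaginary operands accumulates C = A * B using the usual complex product, real = ar*br - ai*bi and imag = ar*bi + ai*br:

// Generic reference complex GEMM on split real/imag buffers (row-major M x K, K x N, M x N).
// Hypothetical helper for illustration; not the collapsed file's contents.
#include <cstddef>
#include <vector>

void reference_cgemm_mk_kn_mn(const std::vector<float>& a_re, const std::vector<float>& a_im,
                              const std::vector<float>& b_re, const std::vector<float>& b_im,
                              std::vector<float>& c_re, std::vector<float>& c_im,
                              std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc_re = 0.f;
            float acc_im = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                const float ar = a_re[m * K + k], ai = a_im[m * K + k];
                const float br = b_re[k * N + n], bi = b_im[k * N + n];
                acc_re += ar * br - ai * bi; // real part of the complex product
                acc_im += ar * bi + ai * br; // imaginary part
            }
            c_re[m * N + n] = acc_re;
            c_im[m * N + n] = acc_im;
        }
}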
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
...
@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ReduceSum = ck::reduce::Add<F32>;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;

+using Div      = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
 using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
 using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;

 using DInElementOps  = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;

 using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                           ck::InMemoryDataOperationEnum::AtomicAdd>;
...
@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 // c[m, n] = a[k, m] * b[k, n]
 using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple<
     // clang-format off
-    //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+    //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
     //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
     //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
     //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
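The functional change in this instance list is the D-output elementwise tuple: DOutElementOps moves from {Identity, Identity} to {Div, Div}, where Div is UnaryIdentic with its third template argument set to true (the variant that divides by a supplied divider, as its use for averaging elsewhere in this commit suggests), and the column header is renamed from DxsOutEleOp to DxsAccEleOp to match. With DInElementOps = {Identity, Square}, a summing reduction and AtomicAdd stores, the two auxiliary outputs amount, roughly, to a per-row mean of c and a per-row mean of c squared. A small host-side sketch of that relationship, offered as an illustration of the math rather than of the kernel code:

// What the two reduced D outputs represent after this change (math-only sketch).
#include <cstddef>
#include <vector>

void rowwise_mean_and_meansq(const std::vector<float>& c, // M x N row-major GEMM output
                             std::size_t M, std::size_t N,
                             std::vector<float>& d0,       // per-row mean of c
                             std::vector<float>& d1)       // per-row mean of c * c
{
    for(std::size_t m = 0; m < M; ++m)
    {
        float sum = 0.f, sum_sq = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            const float v = c[m * N + n];
            sum    += v;     // DxsInEleOp[0] = Identity, reduced with Add
            sum_sq += v * v; // DxsInEleOp[1] = Square,   reduced with Add
        }
        d0[m] = sum / static_cast<float>(N);    // DxsAccEleOp[0] = Div
        d1[m] = sum_sq / static_cast<float>(N); // DxsAccEleOp[1] = Div
    }
}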