gaoqiong / composable_kernel

Commit 0f421d6f, authored Apr 20, 2022 by rocking
[What] Add ComputeDataType to the eltwise kernel
[Why] Similar to the acc data type, it increases precision
parent cf326690
Showing 3 changed files with 59 additions and 31 deletions.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp (+50 / -25)
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp (+2 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp (+7 / -6)
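For context (not part of the commit itself): the pattern this change adopts is to keep tensor storage in F16 while doing the elementwise arithmetic in a wider compute type, mirroring how AccDataType already works for the GEMM accumulation. A minimal, self-contained sketch of that idea, with float/double standing in for the F16/F32 pair:

    #include <cstdio>

    // Stand-ins: StorageType plays the role of CDataType (F16 in the example),
    // ComputeType plays the role of the new EltwiseComputeDataType (F32).
    using StorageType = float;
    using ComputeType = double;

    // Elementwise binary op: widen on load, compute, narrow on store.
    template <typename Compute, typename Storage, typename F>
    void elementwise(const Storage* a, const Storage* b, Storage* c, int n, F f)
    {
        for(int i = 0; i < n; ++i)
        {
            Compute x = static_cast<Compute>(a[i]);
            Compute y = static_cast<Compute>(b[i]);
            Compute r{};
            f(r, x, y);                     // all arithmetic happens in Compute
            c[i] = static_cast<Storage>(r); // rounded back only once, at the end
        }
    }

    int main()
    {
        StorageType a[2] = {1.5f, 2.0f}, b[2] = {0.5f, 0.25f}, c[2];
        elementwise<ComputeType>(a, b, c, 2,
                                 [](ComputeType& d, ComputeType x, ComputeType y) { d = x - y; });
        std::printf("%f %f\n", c[0], c[1]);
    }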
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
@@ -36,10 +36,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-using ADataType   = F16;
-using BDataType   = F16;
-using CDataType   = F16;
-using AccDataType = F32;
+using ADataType              = F16;
+using BDataType              = F16;
+using CDataType              = F16;
+using AccDataType            = F32;
+using EltwiseComputeDataType = F32;

 // CAUSION - host reduce_max will call numeric_limits<ck::half_t>::lowest()
 // However, numeric_limits<ck::half_t>::lowest() will return zero. So, used half_float::half instead
@@ -103,10 +104,10 @@ using ReduceMaxInElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
 using ReduceMaxAccElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
-using ReduceSumInElementwiseOperation =
-    typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
-using ReduceSumAccElementwiseOperation =
-    typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
+using ReduceSumInElementwiseOperation =
+    typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
+using ReduceSumAccElementwiseOperation =
+    typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;

 using DeviceReduceMaxInstance = ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
@@ -150,30 +151,36 @@ using DeviceReduceSumInstance =
 struct Sub_Exp
 {
-    __host__ __device__ constexpr void
-    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
+    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
+                                                  const EltwiseComputeDataType& src1,
+                                                  const EltwiseComputeDataType& src2) const
     {
-        dst = src1 - src2;
-        // FIXME - use float16 exponential
-        float dst_f32 = static_cast<float>(dst);
-        dst           = static_cast<CDataType>(exp(dst_f32));
+        dst = exp(src1 - src2);
     }
 };

 struct Div
 {
-    __host__ __device__ constexpr void
-    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
+    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
+                                                  const EltwiseComputeDataType& src1,
+                                                  const EltwiseComputeDataType& src2) const
     {
         dst = src1 / src2;
     }
 };

-using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
-    DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 256, 32, 8>;
+using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
+    DeviceElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Sub_Exp, 256, 32, 8>;

 using DeviceElementwiseDivInstance = ck::tensor_operation::device::
-    DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;
+    DeviceElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;

 using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
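A host-only sketch of what the updated functors compute (assumed, simplified types; the real functors are __host__ __device__ and take EltwiseComputeDataType): the point is that src1 - src2 is no longer rounded to F16 before the exponential.

    #include <cmath>
    #include <cstdio>

    // Assumed stand-in for EltwiseComputeDataType (F32 in the example above).
    using ComputeType = float;

    struct Sub_Exp
    {
        void operator()(ComputeType& dst, const ComputeType& src1, const ComputeType& src2) const
        {
            dst = std::exp(src1 - src2); // whole expression evaluated in ComputeType
        }
    };

    struct Div
    {
        void operator()(ComputeType& dst, const ComputeType& src1, const ComputeType& src2) const
        {
            dst = src1 / src2;
        }
    };

    int main()
    {
        ComputeType e{}, s{};
        Sub_Exp{}(e, 2.5f, 1.0f); // e = exp(1.5)
        Div{}(s, e, 4.0f);        // s = exp(1.5) / 4
        std::printf("%f %f\n", e, s);
    }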
@@ -199,6 +206,7 @@ using HostReduceSumInstance = ReductionHost<HostReduceDataType,
 template <typename HostTensorA,
           typename HostTensorB,
           typename HostTensorC,
+          typename ComputeDataType,
           typename Functor,
           int broadcastDim>
 void host_broadcast2D(
@@ -208,10 +216,19 @@ void host_broadcast2D(
     {
         for(int n = 0; n < N; ++n)
         {
+            ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
+            ComputeDataType Cmn = 0;
             if constexpr(broadcastDim == 1)
-                functor(C(m, n), A(m, n), B(n));
+            {
+                ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
+                functor(Cmn, Amn, Bn);
+            }
             else
-                functor(C(m, n), A(m, n), B(m));
+            {
+                ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
+                functor(Cmn, Amn, Bm);
+            }
+            C(m, n) = static_cast<ComputeDataType>(Cmn);
         }
     }
 }
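The same host reference pattern in isolation (a sketch using plain row-major vectors instead of the CK Tensor class): every operand is widened to ComputeDataType, the functor runs in that type, and the value is narrowed when it is written back to the output tensor.

    #include <vector>

    // Simplified stand-in for the templated host reference; indices are row-major m * N + n.
    template <typename ComputeDataType, typename StorageType, typename Functor, int broadcastDim>
    void host_broadcast2D_sketch(std::vector<StorageType>& C,
                                 const std::vector<StorageType>& A,
                                 const std::vector<StorageType>& B,
                                 int M, int N, Functor functor)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                ComputeDataType Amn = static_cast<ComputeDataType>(A[m * N + n]);
                ComputeDataType Bx  = static_cast<ComputeDataType>(broadcastDim == 1 ? B[n] : B[m]);
                ComputeDataType Cmn{};
                functor(Cmn, Amn, Bx);                        // arithmetic in ComputeDataType
                C[m * N + n] = static_cast<StorageType>(Cmn); // narrowed on store
            }
    }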
@@ -490,8 +507,12 @@ int main(int argc, char* argv[])
                         reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
                         host_indices.mData.data());

-    host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
-        host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
+    host_broadcast2D<Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     EltwiseComputeDataType,
+                     Sub_Exp,
+                     1>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});

     host_reduce_sum.Run(1, // alpha
                         reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
@@ -499,8 +520,12 @@ int main(int argc, char* argv[])
                         reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
                         host_indices.mData.data());

-    host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
-        host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
+    host_broadcast2D<Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     EltwiseComputeDataType,
+                     Div,
+                     1>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});

     bool result = true;
     if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
@@ -13,6 +13,7 @@ namespace device {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename ComputeDataType,
           typename ElementwiseFunctor,
           index_t ThreadPerBlock,
           index_t ThreadTileSize,
@@ -43,6 +44,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
     using GridwiseEltwise = GridwiseElementwise_1D<ADataType,
                                                    BDataType,
                                                    CDataType,
+                                                   ComputeDataType,
                                                    GridDesc_M0,
                                                    ElementwiseFunctor,
                                                    ThreadPerBlock,
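Schematically (simplified, not the actual CK class bodies), the header change just forwards the new template parameter from the device-level wrapper to the gridwise kernel that owns the thread buffers:

    // Gridwise layer: needs ComputeDataType because it declares the per-thread buffers.
    template <typename ADataType, typename BDataType, typename CDataType,
              typename ComputeDataType, typename ElementwiseFunctor>
    struct GridwiseElementwiseSketch;

    // Device layer: accepts ComputeDataType and passes it through unchanged.
    template <typename ADataType, typename BDataType, typename CDataType,
              typename ComputeDataType, typename ElementwiseFunctor>
    struct DeviceElementwiseSketch
    {
        using GridwiseEltwise = GridwiseElementwiseSketch<ADataType,
                                                          BDataType,
                                                          CDataType,
                                                          ComputeDataType,
                                                          ElementwiseFunctor>;
    };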
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp
@@ -33,6 +33,7 @@ __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename ComputeDataType,
           typename GridDesc_M0,
           typename ElementwiseFunctor,
           index_t ThreadPerBlock,
@@ -70,15 +71,15 @@ struct GridwiseElementwise_1D
         auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_global, c_grid_desc_m0.GetElementSpaceSize());

-        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, ScalarPerVector, true> c_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;

         const auto thread_to_global_offset = CalculateElementwiseIndex();

         auto a_global_load = ThreadwiseTensorSliceTransfer_v2<ADataType,
-                                                              ADataType,
+                                                              ComputeDataType,
                                                               GridDesc_M0,
                                                               decltype(thread_desc_M0),
                                                               Sequence<ScalarPerVector>, // SliceLengths
@@ -90,7 +91,7 @@ struct GridwiseElementwise_1D
         auto b_global_load = ThreadwiseTensorSliceTransfer_v2<BDataType,
-                                                              BDataType,
+                                                              ComputeDataType,
                                                               GridDesc_M0,
                                                               decltype(thread_desc_M0),
                                                               Sequence<ScalarPerVector>, // SliceLengths
@@ -101,7 +102,7 @@ struct GridwiseElementwise_1D
                                                               false>{b_grid_desc_m0, thread_to_global_offset};

         auto c_global_write =
-            ThreadwiseTensorSliceTransfer_v1r3<CDataType,
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                                CDataType,
                                                decltype(thread_desc_M0),
                                                GridDesc_M0,
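Put together, the gridwise change gives each thread this data flow (a plain-C++ sketch of the intent, not the CK ThreadwiseTensorSliceTransfer machinery): the loads convert ADataType/BDataType into ComputeDataType registers, the functor runs in ComputeDataType, and the write-out converts back to CDataType.

    // Per-thread pipeline sketch; ScalarPerVector elements handled by one thread.
    template <typename ADataType, typename BDataType, typename CDataType,
              typename ComputeDataType, typename Functor, int ScalarPerVector>
    void thread_tile_sketch(const ADataType* a, const BDataType* b, CDataType* c, Functor functor)
    {
        ComputeDataType a_buf[ScalarPerVector]; // stand-ins for the Vgpr StaticBuffers,
        ComputeDataType b_buf[ScalarPerVector]; // which are now all ComputeDataType
        ComputeDataType c_buf[ScalarPerVector];

        for(int i = 0; i < ScalarPerVector; ++i)
        {
            a_buf[i] = static_cast<ComputeDataType>(a[i]); // a_global_load: ADataType -> ComputeDataType
            b_buf[i] = static_cast<ComputeDataType>(b[i]); // b_global_load: BDataType -> ComputeDataType
        }

        for(int i = 0; i < ScalarPerVector; ++i)
            functor(c_buf[i], a_buf[i], b_buf[i]);         // elementwise op in ComputeDataType

        for(int i = 0; i < ScalarPerVector; ++i)
            c[i] = static_cast<CDataType>(c_buf[i]);       // c_global_write: ComputeDataType -> CDataType
    }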