gaoqiong / composable_kernel / Commits

Commit b05a594e, authored Apr 13, 2022 by rocking
Parent: 30348daa

Add reduce sum for denominator of softmax
Showing 1 changed file with 103 additions and 43 deletions.

example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp (+103 / -43)
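For orientation (this note is not part of the commit): the example is assembling a numerically stable softmax over the GEMM output in stages — reduce max, broadcast subtract-and-exponentiate, reduce sum (added by this commit), and a broadcast divide that is still TODO. With the reduction taken over the example's reduce dimension:

$$\operatorname{softmax}(x)_i = \frac{e^{\,x_i - \max_j x_j}}{\sum_j e^{\,x_j - \max_j x_j}}$$

The max subtraction keeps the fp16 exponentials from overflowing; the sum in the denominator is what the new DeviceReduceSumInstance computes.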
@@ -85,27 +85,53 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
         8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on

-constexpr int ReduceRank   = 2;
-constexpr int NumReduceDim = 1;
+constexpr int Rank         = 2;
+constexpr int NumReduceDim = 1;

 constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
-constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
+constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
+constexpr ck::NanPropagation NanOpt      = ck::NanPropagation::PROPAGATE_NAN;

 constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
 // constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;

 using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
+using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType;

-using InElementwiseOperation =
+using ReduceMaxInElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
-using AccElementwiseOperation =
+using ReduceMaxAccElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
+using ReduceSumInElementwiseOperation =
+    typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::InElementwiseOperation;
+using ReduceSumAccElementwiseOperation =
+    typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::AccElementwiseOperation;

 using DeviceReduceMaxInstance =
     ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
                                                         CDataType,
                                                         CDataType,
-                                                        ReduceRank,
+                                                        Rank,
                                                         NumReduceDim,
                                                         ReduceMaxOp,
-                                                        InElementwiseOperation,
-                                                        AccElementwiseOperation,
+                                                        ReduceMaxInElementwiseOperation,
+                                                        ReduceMaxAccElementwiseOperation,
                                                         PropagateNan,
                                                         false,
                                                         256,
                                                         4,
                                                         64,
                                                         1,
                                                         1,
                                                         0,
                                                         1,
                                                         1>;
+
+using DeviceReduceSumInstance =
+    ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
+                                                        CDataType,
+                                                        CDataType,
+                                                        Rank,
+                                                        NumReduceDim,
+                                                        ReduceSumOp,
+                                                        ReduceSumInElementwiseOperation,
+                                                        ReduceSumAccElementwiseOperation,
+                                                        PropagateNan,
+                                                        false,
+                                                        256,
...
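As a reading aid, here is a host-side sketch (not from the commit) of what the two DeviceReduceBlockWise instances compute, assuming the reduce dimension is M — which matches the length-N c_n_max and exp_n_sum tensors declared further down — and using float in place of the example's fp16 CDataType:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Reference for DeviceReduceMaxInstance: out[n] = max over m of c[m][n].
std::vector<float> reduce_max_ref(const std::vector<float>& c, int M, int N)
{
    std::vector<float> out(N, -std::numeric_limits<float>::infinity());
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            out[n] = std::max(out[n], c[m * N + n]);
    return out;
}

// Reference for DeviceReduceSumInstance applied to exp(c - max):
// out[n] = sum over m of exp(c[m][n] - c_max[n]), the softmax denominator.
std::vector<float> reduce_sum_exp_ref(const std::vector<float>& c,
                                      const std::vector<float>& c_max,
                                      int M,
                                      int N)
{
    std::vector<float> out(N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            out[n] += std::exp(c[m * N + n] - c_max[n]);
    return out;
}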
@@ -119,12 +145,13 @@ using DeviceReduceMaxInstance =
 struct Sub_Exp
 {
     __host__ __device__ constexpr void
     operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
     {
         dst = src1 - src2;

         // FIXME - use float16 exponential
         float dst_f32 = static_cast<float>(dst);
         dst           = static_cast<CDataType>(exp(dst_f32));
     }
 };
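The FIXME above records that the exponential is still evaluated in fp32: the fp16 value is widened, exp'd, and narrowed back. One possible direction, sketched here under the assumption that CDataType is layout-compatible with __half and that HIP's half-precision intrinsics (hexp and __hsub from <hip/hip_fp16.h>) are precise enough for this example:

#include <hip/hip_fp16.h>

// Sketch only, not the committed code: a Sub_Exp that stays in fp16 end to
// end by using the half-precision subtract and exponential intrinsics.
// Device-side only, since the intrinsics have no host counterpart.
struct Sub_Exp_Fp16
{
    __device__ void operator()(__half& dst, const __half& src1, const __half& src2) const
    {
        dst = hexp(__hsub(src1, src2)); // exp(src1 - src2) without the fp32 round trip
    }
};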
@@ -198,22 +225,25 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<int> c_m_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
-                          std::vector<std::size_t>({1}));
-    Tensor<CDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<int> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                        std::vector<std::size_t>({1}));
+    Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                std::vector<std::size_t>({1}));

-    const auto i_inLengths  = ck::to_int_vector(c_m_n.mDesc.GetLengths());
-    const auto i_inStrides  = ck::to_int_vector(c_m_n.mDesc.GetStrides());
-    const auto i_outLengths = ck::to_int_vector(c_m_n_max.mDesc.GetLengths());
-    const auto i_outStrides = ck::to_int_vector(c_m_n_max.mDesc.GetStrides());
+    const auto c_m_n_shape     = ck::to_int_vector(c_m_n.mDesc.GetLengths());
+    const auto c_m_n_stride    = ck::to_int_vector(c_m_n.mDesc.GetStrides());
+    const auto reduce_n_shape  = ck::to_int_vector(c_n_max.mDesc.GetLengths());
+    const auto reduce_n_stride = ck::to_int_vector(c_n_max.mDesc.GetStrides());

-    size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_m_n_max.mDesc.GetElementSize();
+    size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_n_max.mDesc.GetElementSize();

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n.mDesc << std::endl;
-    std::cout << "c_m_n_max: " << c_m_n_max.mDesc << std::endl;
-    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+    std::cout << "c_n_max: " << c_n_max.mDesc << std::endl;
+    std::cout << "exp_m_n: " << exp_m_n.mDesc << std::endl;
+    std::cout << "exp_n_sum: " << exp_n_sum.mDesc << std::endl;

     switch(init_method)
     {
@@ -230,9 +260,10 @@ int main(int argc, char* argv[])
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
-    DeviceMem c_m_n_max_device_buf(sizeof(CDataType) * c_m_n_max.mDesc.GetElementSpace());
-    DeviceMem c_m_n_max_indices_dev(0);
-    DeviceMem d_m_n_device_buf(sizeof(CDataType) * d_m_n.mDesc.GetElementSpace());
+    DeviceMem c_n_max_device_buf(sizeof(CDataType) * c_n_max.mDesc.GetElementSpace());
+    DeviceMem indices_device_buf(0);
+    DeviceMem exp_m_n_device_buf(sizeof(CDataType) * exp_m_n.mDesc.GetElementSpace());
+    DeviceMem exp_n_sum_device_buf(sizeof(CDataType) * exp_n_sum.mDesc.GetElementSpace());

     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
@@ -265,23 +296,23 @@ int main(int argc, char* argv[])
     // do reduce max
     auto reduce_max = DeviceReduceMaxInstance{};

-    auto wsSizeInBytes = reduce_max.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
-    DeviceMem ws_dev(wsSizeInBytes);
+    auto reduce_max_workaspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+    DeviceMem reduce_max_workaspace_device_buf(reduce_max_workaspace_size);

     auto reduce_max_argument_ptr =
-        reduce_max.MakeArgumentPointer(i_inLengths,
-                                       i_inStrides,
-                                       i_outLengths,
-                                       i_outStrides,
+        reduce_max.MakeArgumentPointer(c_m_n_shape,
+                                       c_m_n_stride,
+                                       reduce_n_shape,
+                                       reduce_n_stride,
                                        reduceDims,
                                        1,
                                        0,
                                        c_m_n_device_buf.GetDeviceBuffer(),
-                                       c_m_n_max_device_buf.GetDeviceBuffer(),
-                                       c_m_n_max_indices_dev.GetDeviceBuffer(),
-                                       ws_dev.GetDeviceBuffer(),
-                                       InElementwiseOperation{static_cast<int>(reduce_total_length)},
-                                       AccElementwiseOperation{static_cast<int>(reduce_total_length)});
+                                       c_n_max_device_buf.GetDeviceBuffer(),
+                                       indices_device_buf.GetDeviceBuffer(),
+                                       reduce_max_workaspace_device_buf.GetDeviceBuffer(),
+                                       ReduceMaxInElementwiseOperation{static_cast<int>(reduce_total_length)},
+                                       ReduceMaxAccElementwiseOperation{static_cast<int>(reduce_total_length)});

     if(!reduce_max.IsSupportedArgument(reduce_max_argument_ptr.get()))
     {
@@ -292,17 +323,17 @@ int main(int argc, char* argv[])
     auto reduce_max_invoker_ptr = reduce_max.MakeInvokerPointer();
     reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);

-    // do broadcast sub
+    // do broadcast sub and exp
     auto broadcastSubExp = DeviceElementwiseInstance{};
     auto broadcastSubExp_argument_ptr =
         broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
-                                            c_m_n_max_device_buf.GetDeviceBuffer(),
-                                            d_m_n_device_buf.GetDeviceBuffer(),
-                                            {M, N},
-                                            {StrideC, 1},
-                                            {0, 1},
-                                            {StrideC, 1},
-                                            Sub_Exp{});
+                                            c_n_max_device_buf.GetDeviceBuffer(),
+                                            exp_m_n_device_buf.GetDeviceBuffer(),
+                                            {M, N},
+                                            {StrideC, 1},
+                                            {0, 1},
+                                            {StrideC, 1},
+                                            Sub_Exp{});

     if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
     {
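A note on the arguments above: the {0, 1} strides given for the second input are what implement the broadcast — the length-N max tensor is read with a zero stride along M, so every row sees the same N values. A minimal host-side analogue of that stride-based broadcast (illustrative only, float standing in for CDataType):

#include <cmath>
#include <vector>

// Index a logical [M, N] view through explicit strides; a zero stride along
// m re-reads the same N values for every row, i.e. a broadcast.
inline float load_2d(const float* buf, int m, int n, int stride_m, int stride_n)
{
    return buf[m * stride_m + n * stride_n];
}

// Reference for the broadcastSubExp stage: out = exp(c - c_max), with c_max
// broadcast over M via its {0, 1} strides.
void broadcast_sub_exp_ref(
    const std::vector<float>& c, const std::vector<float>& c_max, std::vector<float>& out, int M, int N)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            out[m * N + n] =
                std::exp(load_2d(c.data(), m, n, N, 1) - load_2d(c_max.data(), m, n, 0, 1));
}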
@@ -313,7 +344,36 @@ int main(int argc, char* argv[])
     auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
     broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);

-    // TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
+    // do reduce sum - denominator of softmax
+    auto reduce_sum = DeviceReduceSumInstance{};
+
+    auto reduce_sum_workaspace_size = reduce_sum.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+    DeviceMem reduce_sum_workaspace_device_buf(reduce_sum_workaspace_size);
+
+    auto reduce_sum_argument_ptr =
+        reduce_sum.MakeArgumentPointer(c_m_n_shape,
+                                       c_m_n_stride,
+                                       reduce_n_shape,
+                                       reduce_n_stride,
+                                       reduceDims,
+                                       1,
+                                       0,
+                                       exp_m_n_device_buf.GetDeviceBuffer(),
+                                       exp_n_sum_device_buf.GetDeviceBuffer(),
+                                       indices_device_buf.GetDeviceBuffer(),
+                                       reduce_sum_workaspace_device_buf.GetDeviceBuffer(),
+                                       ReduceSumInElementwiseOperation{static_cast<int>(reduce_total_length)},
+                                       ReduceSumAccElementwiseOperation{static_cast<int>(reduce_total_length)});
+
+    if(!reduce_sum.IsSupportedArgument(reduce_sum_argument_ptr.get()))
+    {
+        throw std::runtime_error(
+            "The runtime parameters seems not supported by the DeviceReduce instance, exiting!");
+    };
+
+    auto reduce_sum_invoker_ptr = reduce_sum.MakeInvokerPointer();
+    reduce_sum_invoker_ptr->Run(reduce_sum_argument_ptr.get(), nrepeat);
+
+    // TODO - Need BroadcastDiv

     // TODO = do_verification
     (void)do_verification;

     return 0;
...
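What remains for the BroadcastDiv TODO is the final normalization: dividing each exponential by its reduced sum. A host-side sketch of that step (not from the commit), under the same assumptions as the earlier references — reduction over M, float in place of the fp16 CDataType:

#include <vector>

// Reference for the missing BroadcastDiv stage: softmax output
// d[m][n] = exp_m_n[m][n] / exp_n_sum[n], normalizing in place.
void broadcast_div_ref(std::vector<float>& exp_m_n, const std::vector<float>& exp_n_sum, int M, int N)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            exp_m_n[m * N + n] /= exp_n_sum[n]; // broadcast the length-N sums across rows
}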
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment