Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6a781e51
Commit
6a781e51
authored
Apr 13, 2022
by
rocking
Browse files
Add broadcast div, the final step of softmax
parent
b05a594e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
8 deletions
+43
-8
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+43
-8
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
6a781e51
...
...
@@ -155,9 +155,21 @@ struct Sub_Exp
}
};
using
DeviceElementwiseInstance
=
ck
::
tensor_operation
::
device
::
struct
Div
{
__host__
__device__
constexpr
void
operator
()(
CDataType
&
dst
,
const
CDataType
&
src1
,
const
CDataType
&
src2
)
const
{
dst
=
src1
/
src2
;
}
};
using
DeviceElementwiseSubExpInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
Sub_Exp
,
16
,
16
,
8
,
8
,
1
,
1
,
1
,
1
,
1
>
;
using
DeviceElementwiseDivInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
Div
,
16
,
16
,
8
,
8
,
1
,
1
,
1
,
1
,
1
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
...
...
@@ -230,6 +242,7 @@ int main(int argc, char* argv[])
Tensor
<
CDataType
>
exp_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
exp_n_sum
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
N
)}),
std
::
vector
<
std
::
size_t
>
({
1
}));
Tensor
<
CDataType
>
softmax_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
const
auto
c_m_n_shape
=
ck
::
to_int_vector
(
c_m_n
.
mDesc
.
GetLengths
());
const
auto
c_m_n_stride
=
ck
::
to_int_vector
(
c_m_n
.
mDesc
.
GetStrides
());
...
...
@@ -244,6 +257,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"c_n_max: "
<<
c_n_max
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"exp_m_n: "
<<
exp_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"exp_n_sum: "
<<
exp_n_sum
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"softmax_m_n: "
<<
softmax_m_n
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -264,6 +278,7 @@ int main(int argc, char* argv[])
DeviceMem
indices_device_buf
(
0
);
DeviceMem
exp_m_n_device_buf
(
sizeof
(
CDataType
)
*
exp_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
exp_n_sum_device_buf
(
sizeof
(
CDataType
)
*
exp_n_sum
.
mDesc
.
GetElementSpace
());
DeviceMem
softmax_m_n_device_buf
(
sizeof
(
CDataType
)
*
softmax_m_n
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
...
@@ -324,7 +339,7 @@ int main(int argc, char* argv[])
reduce_max_invoker_ptr
->
Run
(
reduce_max_argument_ptr
.
get
(),
nrepeat
);
// do broadcast sub and exp
auto
broadcastSubExp
=
DeviceElementwiseInstance
{};
auto
broadcastSubExp
=
DeviceElementwise
SubExp
Instance
{};
auto
broadcastSubExp_argument_ptr
=
broadcastSubExp
.
MakeArgumentPointer
(
c_m_n_device_buf
.
GetDeviceBuffer
(),
c_n_max_device_buf
.
GetDeviceBuffer
(),
...
...
@@ -373,7 +388,27 @@ int main(int argc, char* argv[])
auto
reduce_sum_invoker_ptr
=
reduce_sum
.
MakeInvokerPointer
();
reduce_sum_invoker_ptr
->
Run
(
reduce_sum_argument_ptr
.
get
(),
nrepeat
);
// TODO - Need BroadcastDiv
// do broadcast div
auto
broadcastDiv
=
DeviceElementwiseDivInstance
{};
auto
broadcastDiv_argument_ptr
=
broadcastDiv
.
MakeArgumentPointer
(
exp_m_n_device_buf
.
GetDeviceBuffer
(),
exp_n_sum_device_buf
.
GetDeviceBuffer
(),
softmax_m_n_device_buf
.
GetDeviceBuffer
(),
{
M
,
N
},
{
StrideC
,
1
},
{
0
,
1
},
{
StrideC
,
1
},
Div
{});
if
(
!
broadcastDiv
.
IsSupportedArgument
(
broadcastDiv_argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!"
);
};
auto
broadcastDiv_invoker_ptr
=
broadcastDiv
.
MakeInvokerPointer
();
broadcastDiv_invoker_ptr
->
Run
(
broadcastDiv_argument_ptr
.
get
(),
nrepeat
);
// TODO = do_verification
(
void
)
do_verification
;
return
0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment