Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
21802fda
Commit
21802fda
authored
Apr 18, 2022
by
rocking
Browse files
[What] Sync input of each host kernel and device kernel
[Why] Prevent error propogation
parent
e83b22e0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
10 deletions
+11
-10
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+11
-10
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
21802fda
...
@@ -455,6 +455,13 @@ int main(int argc, char* argv[])
...
@@ -455,6 +455,13 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
std
::
cout
<<
"verification..."
<<
std
::
endl
;
std
::
cout
<<
"verification..."
<<
std
::
endl
;
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
c_n_max_device_buf
.
FromDevice
(
c_n_max
.
mData
.
data
());
exp_m_n_device_buf
.
FromDevice
(
exp_m_n
.
mData
.
data
());
exp_n_sum_device_buf
.
FromDevice
(
exp_n_sum
.
mData
.
data
());
softmax_m_n_device_buf
.
FromDevice
(
softmax_m_n
.
mData
.
data
());
const
std
::
vector
<
int
>
reduceInvariantDims
{
1
};
const
std
::
vector
<
int
>
reduceInvariantDims
{
1
};
Tensor
<
CDataType
>
host_c_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
host_c_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
host_c_n_max
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
N
)}),
Tensor
<
CDataType
>
host_c_n_max
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
N
)}),
...
@@ -478,28 +485,22 @@ int main(int argc, char* argv[])
...
@@ -478,28 +485,22 @@ int main(int argc, char* argv[])
host_gemm_invoker
.
Run
(
host_gemm_argument
);
host_gemm_invoker
.
Run
(
host_gemm_argument
);
host_reduce_max
.
Run
(
1
,
// alpha
host_reduce_max
.
Run
(
1
,
// alpha
reinterpret_cast
<
const
HostReduceDataType
*>
(
host_
c_m_n
.
mData
.
data
()),
reinterpret_cast
<
const
HostReduceDataType
*>
(
c_m_n
.
mData
.
data
()),
0
,
// beta
0
,
// beta
reinterpret_cast
<
HostReduceDataType
*>
(
host_c_n_max
.
mData
.
data
()),
reinterpret_cast
<
HostReduceDataType
*>
(
host_c_n_max
.
mData
.
data
()),
host_indices
.
mData
.
data
());
host_indices
.
mData
.
data
());
host_broadcast2D
<
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Sub_Exp
,
1
>
(
host_broadcast2D
<
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Sub_Exp
,
1
>
(
host_exp_m_n
,
host_
c_m_n
,
host_
c_n_max
,
M
,
N
,
Sub_Exp
{});
host_exp_m_n
,
c_m_n
,
c_n_max
,
M
,
N
,
Sub_Exp
{});
host_reduce_sum
.
Run
(
1
,
// alpha
host_reduce_sum
.
Run
(
1
,
// alpha
reinterpret_cast
<
const
HostReduceDataType
*>
(
host_
exp_m_n
.
mData
.
data
()),
reinterpret_cast
<
const
HostReduceDataType
*>
(
exp_m_n
.
mData
.
data
()),
0
,
// beta
0
,
// beta
reinterpret_cast
<
HostReduceDataType
*>
(
host_exp_n_sum
.
mData
.
data
()),
reinterpret_cast
<
HostReduceDataType
*>
(
host_exp_n_sum
.
mData
.
data
()),
host_indices
.
mData
.
data
());
host_indices
.
mData
.
data
());
host_broadcast2D
<
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Div
,
1
>
(
host_broadcast2D
<
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
Div
,
1
>
(
host_softmax_m_n
,
host_exp_m_n
,
host_exp_n_sum
,
M
,
N
,
Div
{});
host_softmax_m_n
,
exp_m_n
,
exp_n_sum
,
M
,
N
,
Div
{});
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
c_n_max_device_buf
.
FromDevice
(
c_n_max
.
mData
.
data
());
exp_m_n_device_buf
.
FromDevice
(
exp_m_n
.
mData
.
data
());
exp_n_sum_device_buf
.
FromDevice
(
exp_n_sum
.
mData
.
data
());
softmax_m_n_device_buf
.
FromDevice
(
softmax_m_n
.
mData
.
data
());
bool
result
=
true
;
bool
result
=
true
;
if
(
result
&=
ck
::
utils
::
check_err
(
c_m_n
.
mData
,
host_c_m_n
.
mData
))
if
(
result
&=
ck
::
utils
::
check_err
(
c_m_n
.
mData
,
host_c_m_n
.
mData
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment