Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c1ed00b6
Commit
c1ed00b6
authored
Feb 10, 2023
by
ltqin
Browse files
fix:yygrad_threadwise_copy InvalidElementAsNaN to false
parent
0016f6ac
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
5 deletions
+3
-5
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+1
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+2
-2
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
c1ed00b6
...
@@ -30,7 +30,6 @@ Kernel outputs:
...
@@ -30,7 +30,6 @@ Kernel outputs:
#include <numeric>
#include <numeric>
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...
@@ -611,9 +610,8 @@ int run(int argc, char* argv[])
...
@@ -611,9 +610,8 @@ int run(int argc, char* argv[])
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
// copy z matrix data from device
// copy z matrix data from device
std
::
ofstream
file
(
"./z_matrix_txt"
);
z_device_buf
.
FromDevice
(
z_g_m_n
.
mData
.
data
());
z_device_buf
.
FromDevice
(
z_g_m_n
.
mData
.
data
());
file
<<
z_g_m_n
<<
std
::
endl
;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
c1ed00b6
...
@@ -1649,8 +1649,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1649,8 +1649,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
true
/* ResetCoordAfterRun */
,
true
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
y_thread_data_on_grid_idx
);
auto
y_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
y_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment