Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
680cfaa7
Commit
680cfaa7
authored
Apr 22, 2022
by
rocking
Browse files
Fix the meaning of broadcast dim parameter
parent
5d36f7a2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+3
-3
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
680cfaa7
...
@@ -223,7 +223,7 @@ void host_broadcast2D(
...
@@ -223,7 +223,7 @@ void host_broadcast2D(
{
{
ComputeDataType
Amn
=
static_cast
<
ComputeDataType
>
(
A
(
m
,
n
));
ComputeDataType
Amn
=
static_cast
<
ComputeDataType
>
(
A
(
m
,
n
));
ComputeDataType
Cmn
=
0
;
ComputeDataType
Cmn
=
0
;
if
constexpr
(
broadcastDim
==
1
)
if
constexpr
(
broadcastDim
==
0
)
{
{
ComputeDataType
Bn
=
static_cast
<
ComputeDataType
>
(
B
(
n
));
ComputeDataType
Bn
=
static_cast
<
ComputeDataType
>
(
B
(
n
));
functor
(
Cmn
,
Amn
,
Bn
);
functor
(
Cmn
,
Amn
,
Bn
);
...
@@ -516,7 +516,7 @@ int main(int argc, char* argv[])
...
@@ -516,7 +516,7 @@ int main(int argc, char* argv[])
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
EltwiseComputeDataType
,
Sub_Exp
,
Sub_Exp
,
1
>
(
host_exp_m_n
,
c_m_n
,
c_n_max
,
M
,
N
,
Sub_Exp
{});
0
>
(
host_exp_m_n
,
c_m_n
,
c_n_max
,
M
,
N
,
Sub_Exp
{});
host_reduce_sum
.
Run
(
1
,
// alpha
host_reduce_sum
.
Run
(
1
,
// alpha
reinterpret_cast
<
const
HostReduceDataType
*>
(
exp_m_n
.
mData
.
data
()),
reinterpret_cast
<
const
HostReduceDataType
*>
(
exp_m_n
.
mData
.
data
()),
...
@@ -529,7 +529,7 @@ int main(int argc, char* argv[])
...
@@ -529,7 +529,7 @@ int main(int argc, char* argv[])
Tensor
<
CDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
EltwiseComputeDataType
,
Div
,
Div
,
1
>
(
host_softmax_m_n
,
exp_m_n
,
exp_n_sum
,
M
,
N
,
Div
{});
0
>
(
host_softmax_m_n
,
exp_m_n
,
exp_n_sum
,
M
,
N
,
Div
{});
bool
result
=
true
;
bool
result
=
true
;
if
(
result
&=
ck
::
utils
::
check_err
(
c_m_n
.
mData
,
host_c_m_n
.
mData
))
if
(
result
&=
ck
::
utils
::
check_err
(
c_m_n
.
mData
,
host_c_m_n
.
mData
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment