Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e21c1785
Commit
e21c1785
authored
Oct 11, 2022
by
Astha Rai
Browse files
changed isSupportedArgument for 2D
parent
64026bc3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
.../ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
+13
-4
No files found.
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
View file @
e21c1785
...
...
@@ -250,26 +250,35 @@ struct DeviceElementwise
if
(
pArg
->
lengths_
.
back
()
%
MPerThread
!=
0
)
return
false
;
std
::
cout
<<
"lengths back: "
<<
pArg
->
lengths_
.
back
()
<<
std
::
endl
;
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
index_t
scalarPerVector
)
{
if
(
strides
.
back
()
==
1
&&
lengths
.
back
()
%
scalarPerVector
==
0
)
return
true
;
std
::
cout
<<
"scalarPerVector: "
<<
scalarPerVector
<<
std
::
endl
;
std
::
cout
<<
"stride back: "
<<
strides
.
back
()
<<
std
::
endl
;
std
::
cout
<<
"ISPVV Check 1 starting"
<<
std
::
endl
;
if
(
strides
.
back
()
==
1
&&
lengths
.
back
()
%
scalarPerVector
==
0
){
return
true
;
}
std
::
cout
<<
"Check 1 failed "
<<
std
::
endl
;
if
(
strides
.
back
()
!=
1
&&
scalarPerVector
==
1
)
return
true
;
std
::
cout
<<
"ISPVV Check 2 starting"
<<
std
::
endl
;
if
(
strides
.
back
()
!=
1
&&
scalarPerVector
==
MPerThread
){
return
true
;
}
return
false
;
};
bool
valid
=
true
;
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
std
::
cout
<<
"running: "
<<
I
<<
std
::
endl
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
std
::
cout
<<
"running 2: "
<<
I
<<
std
::
endl
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment