Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e801e2f7
Commit
e801e2f7
authored
Nov 10, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
e2ec9378
aa56068c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
31 additions
and
14 deletions
+31
-14
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+9
-4
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+1
-1
src/propagate_constant.cpp
src/propagate_constant.cpp
+6
-4
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+11
-0
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+1
-1
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+3
-4
No files found.
src/auto_contiguous.cpp
View file @
e801e2f7
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -62,9 +61,16 @@ void auto_contiguous::apply(module& m) const
...
@@ -62,9 +61,16 @@ void auto_contiguous::apply(module& m) const
{
{
if
(
contains
({
"layout"
,
"contiguous"
,
"@return"
,
"@param"
,
"@outline"
},
ins
->
name
()))
if
(
contains
({
"layout"
,
"contiguous"
,
"@return"
,
"@param"
,
"@outline"
},
ins
->
name
()))
continue
;
continue
;
auto
outputs
=
ins
->
outputs
();
// for last instruction that is NOT a return
// for last instruction that is NOT a return
if
(
ins
->
outputs
()
.
empty
()
and
ins
!=
last
)
if
(
outputs
.
empty
()
and
ins
!=
last
)
continue
;
continue
;
if
(
not
outputs
.
empty
())
// if contiguous was already inserted, skip
if
(
std
::
all_of
(
outputs
.
begin
(),
outputs
.
end
(),
[](
auto
output
)
{
return
output
->
name
()
==
"contiguous"
;
}))
continue
;
shape
s
=
ins
->
get_shape
();
shape
s
=
ins
->
get_shape
();
if
(
s
.
dynamic
())
if
(
s
.
dynamic
())
continue
;
continue
;
...
@@ -73,9 +79,8 @@ void auto_contiguous::apply(module& m) const
...
@@ -73,9 +79,8 @@ void auto_contiguous::apply(module& m) const
if
(
s
.
standard
()
and
ins
->
name
()
==
"@literal"
)
if
(
s
.
standard
()
and
ins
->
name
()
==
"@literal"
)
continue
;
continue
;
if
(
s
.
scalar
()
and
not
contains
(
ins
->
name
(),
"broadcast"
))
if
(
s
.
scalar
()
and
not
contains
(
ins
->
name
(),
"broadcast"
))
{
continue
;
continue
;
}
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
m
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
}
}
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
e801e2f7
...
@@ -65,7 +65,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -65,7 +65,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
{
{
migraphx
::
shape
s
;
migraphx
::
shape
s
;
// input is empty, output is a scalar
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
auto
type
=
l_val
.
get_shape
().
type
();
migraphx
::
argument
input
=
args
[
0
]
->
eval
();
migraphx
::
argument
input
=
args
[
0
]
->
eval
();
if
(
not
input
.
empty
())
if
(
not
input
.
empty
())
{
{
...
...
src/propagate_constant.cpp
View file @
e801e2f7
...
@@ -40,7 +40,7 @@ bool skip_propagate(instruction_ref ins)
...
@@ -40,7 +40,7 @@ bool skip_propagate(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
skip_propagate
(
ins
->
inputs
().
front
());
return
skip_propagate
(
ins
->
inputs
().
front
());
auto
&&
s
=
ins
->
get_shape
();
auto
&&
s
=
ins
->
get_shape
();
if
(
s
.
broadcasted
()
and
not
s
.
scalar
())
if
(
s
.
broadcasted
()
and
not
s
.
scalar
()
and
not
s
.
packed
()
)
return
true
;
return
true
;
if
(
s
.
scalar
()
and
s
.
elements
()
!=
1
)
if
(
s
.
scalar
()
and
s
.
elements
()
!=
1
)
return
true
;
return
true
;
...
@@ -101,9 +101,11 @@ void propagate_constant::apply(module& m) const
...
@@ -101,9 +101,11 @@ void propagate_constant::apply(module& m) const
})(
const_instrs_vec
[
i
]);
})(
const_instrs_vec
[
i
]);
m
.
debug_print
(
inss
);
m
.
debug_print
(
inss
);
}
}
assert
(
literals
[
i
].
get_shape
()
==
const_instrs_vec
[
i
]
->
get_shape
());
auto
in_shape
=
const_instrs_vec
[
i
]
->
get_shape
();
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
assert
(
literals
[
i
].
get_shape
()
==
in_shape
);
m
.
replace_instruction
(
const_instrs_vec
[
i
],
l
);
literal
l
{
in_shape
,
literals
[
i
].
data
()};
auto
l0
=
m
.
add_literal
(
l
);
m
.
replace_instruction
(
const_instrs_vec
[
i
],
l0
);
}
}
}
}
}
}
...
...
src/simplify_algebra.cpp
View file @
e801e2f7
...
@@ -564,6 +564,17 @@ struct find_inner_broadcast
...
@@ -564,6 +564,17 @@ struct find_inner_broadcast
return
3
;
return
3
;
}));
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
std
::
vector
<
shape
>
broadcast_shapes
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
broadcast_shapes
),
[](
auto
broadcast
){
return
broadcast
->
get_shape
();
});
std
::
vector
<
shape
>
common_shapes
;
std
::
transform
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
std
::
back_inserter
(
common_shapes
),
[](
auto
common
){
return
common
->
get_shape
();
});
if
(
broadcast_shapes
==
common_shapes
and
std
::
all_of
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
[](
auto
i
){
return
i
->
name
()
==
"broadcast"
or
i
->
name
()
==
"multibroadcast"
;}))
return
;
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
}
};
};
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
e801e2f7
...
@@ -43,7 +43,7 @@ struct index
...
@@ -43,7 +43,7 @@ struct index
__device__
index_int
nglobal
()
const
{
return
blockDim
.
x
*
gridDim
.
x
;
}
// NOLINT
__device__
index_int
nglobal
()
const
{
return
blockDim
.
x
*
gridDim
.
x
;
}
// NOLINT
__device__
index_int
nlocal
()
const
{
return
blockDim
.
x
;
}
// NOLINT
__device__
index_int
nlocal
()
const
{
return
blockDim
.
x
;
}
// NOLINT
template
<
class
F
>
template
<
class
F
>
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
...
...
test/auto_contiguous_test.cpp
View file @
e801e2f7
...
@@ -179,7 +179,8 @@ TEST_CASE(standard_reshape_lazy)
...
@@ -179,7 +179,8 @@ TEST_CASE(standard_reshape_lazy)
auto
ca
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
add
);
auto
ca
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
add
);
auto
r
=
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
{
2
,
1
,
12
,
5
}}}),
ca
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
{
2
,
1
,
12
,
5
}}}),
ca
);
m2
.
add_return
({
r
});
auto
cr
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
r
);
m2
.
add_return
({
cr
});
}
}
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
...
@@ -201,9 +202,7 @@ TEST_CASE(standard_reshape)
...
@@ -201,9 +202,7 @@ TEST_CASE(standard_reshape)
auto
data
=
m2
.
add_parameter
(
"2x2"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
data
=
m2
.
add_parameter
(
"2x2"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
data
,
data
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
data
,
data
);
auto
ca
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
add
);
auto
ca
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
add
);
// extra contiguous coming from reshape logic which has "requires_std_shape" attribute
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
1
,
12
,
5
}}}),
ca
);
auto
cb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
ca
);
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
1
,
12
,
5
}}}),
cb
);
auto
cr
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
r
);
auto
cr
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
r
);
m2
.
add_return
({
cr
});
m2
.
add_return
({
cr
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment