Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b98308b8
Unverified
Commit
b98308b8
authored
Dec 27, 2022
by
Charlie Lin
Committed by
GitHub
Dec 27, 2022
Browse files
Merge branch 'develop' into dyn_onnx_matmul
parents
b48c4cf6
56c43445
Changes
58
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
344 additions
and
105 deletions
+344
-105
Dockerfile
Dockerfile
+1
-1
src/common.cpp
src/common.cpp
+1
-2
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+2
-2
src/driver/main.cpp
src/driver/main.cpp
+8
-2
src/file_buffer.cpp
src/file_buffer.cpp
+16
-8
src/include/migraphx/file_buffer.hpp
src/include/migraphx/file_buffer.hpp
+1
-1
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+2
-0
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+6
-15
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+6
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+19
-11
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+39
-9
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+115
-36
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+5
-5
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+5
-10
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+6
-0
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+3
-1
src/insert_pad.cpp
src/insert_pad.cpp
+2
-2
src/instruction.cpp
src/instruction.cpp
+18
-0
src/module.cpp
src/module.cpp
+88
-0
No files found.
Dockerfile
View file @
b98308b8
...
@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
...
@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/
llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5
-DBUILD_MIXR_TARGET
=
On
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/
rocMLIR@0f38fb33f518b53b94b541feb9b079668c5518e8
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
src/common.cpp
View file @
b98308b8
...
@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
shape
::
dynamic_dimension
one_dyn_dim
{
1
,
1
,
0
};
std
::
transform
(
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
s0
.
dyn_dims
().
cend
(),
...
@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
...
@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
{
return
a
;
return
a
;
}
}
else
if
(
a
==
one_dyn_dim
or
b
==
one_dyn_dim
)
else
if
(
a
==
1
or
b
==
1
)
{
{
// setting opt to 0, may need to be changed
// setting opt to 0, may need to be changed
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
0
};
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
0
};
...
...
src/dead_code_elimination.cpp
View file @
b98308b8
...
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
...
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
// identity, allocate]
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
i
->
name
().
front
()
!
=
'@'
and
not
(
i
->
name
().
front
()
=
=
'@'
)
and
not
contains
({
"identity"
,
"allocate"
},
i
->
name
())
and
not
contains
({
"undefined"
,
"identity"
,
"allocate"
},
i
->
name
()
))
not
i
->
is_undefined
(
))
continue
;
continue
;
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
std
::
unordered_set
<
instruction_ref
>
visited
;
...
...
src/driver/main.cpp
View file @
b98308b8
...
@@ -109,8 +109,12 @@ struct loader
...
@@ -109,8 +109,12 @@ struct loader
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
output_type
,
ap
(
output_type
,
{
"--cpp"
},
{
"--cpp"
},
ap
.
help
(
"Print out the program as
cpp
program."
),
ap
.
help
(
"Print out the program as
C++
program."
),
ap
.
set_value
(
"cpp"
));
ap
.
set_value
(
"cpp"
));
ap
(
output_type
,
{
"--python"
,
"--py"
},
ap
.
help
(
"Print out the program as python program."
),
ap
.
set_value
(
"py"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
ap
(
output_type
,
{
"--text"
},
{
"--text"
},
...
@@ -259,7 +263,9 @@ struct loader
...
@@ -259,7 +263,9 @@ struct loader
type
=
"binary"
;
type
=
"binary"
;
}
}
if
(
type
==
"cpp"
)
if
(
type
==
"py"
)
p
.
print_py
(
*
os
);
else
if
(
type
==
"cpp"
)
p
.
print_cpp
(
*
os
);
p
.
print_cpp
(
*
os
);
else
if
(
type
==
"graphviz"
)
else
if
(
type
==
"graphviz"
)
p
.
print_graph
(
*
os
,
brief
);
p
.
print_graph
(
*
os
,
brief
);
...
...
src/file_buffer.cpp
View file @
b98308b8
...
@@ -30,23 +30,31 @@ namespace migraphx {
...
@@ -30,23 +30,31 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
template
<
class
T
>
T
generic_read_file
(
const
std
::
string
&
filename
)
T
generic_read_file
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
)
{
{
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
streamsize
size
=
is
.
tellg
();
if
(
nbytes
==
0
)
if
(
size
<
1
)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes
=
is
.
tellg
();
if
(
offset
>
nbytes
)
MIGRAPHX_THROW
(
"offset is larger than file size"
);
nbytes
-=
offset
;
}
if
(
nbytes
<
1
)
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
offset
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
T
buffer
(
nbytes
,
0
);
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
nbytes
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
return
buffer
;
}
}
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
)
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
,
size_t
nbytes
)
{
{
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
);
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
,
offset
,
nbytes
);
}
}
std
::
string
read_string
(
const
std
::
string
&
filename
)
std
::
string
read_string
(
const
std
::
string
&
filename
)
...
...
src/include/migraphx/file_buffer.hpp
View file @
b98308b8
...
@@ -31,7 +31,7 @@
...
@@ -31,7 +31,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
);
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/instruction.hpp
View file @
b98308b8
...
@@ -121,6 +121,8 @@ struct instruction
...
@@ -121,6 +121,8 @@ struct instruction
bool
can_eval
()
const
;
bool
can_eval
()
const
;
bool
is_undefined
()
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
void
finalize
(
context
&
ctx
);
void
finalize
(
context
&
ctx
);
...
...
src/include/migraphx/literal.hpp
View file @
b98308b8
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill
(
start
,
end
);
fill
(
start
,
end
);
}
}
// Directly copies buffer of x
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
{
{
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std
::
shared_ptr
<
char
>
buffer
;
std
::
shared_ptr
<
char
>
buffer
;
shape
m_shape
;
shape
m_shape
;
// Keeps the same data ordering as the given container
template
<
class
Iterator
>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
m_shape
.
visit_type
([
&
](
auto
as
)
{
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
std
::
copy
(
start
,
end
,
output
.
begin
());
}
});
else
{
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
});
});
}
}
}
};
};
...
...
src/include/migraphx/module.hpp
View file @
b98308b8
...
@@ -205,6 +205,12 @@ struct module
...
@@ -205,6 +205,12 @@ struct module
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_cpp
(
std
::
ostream
&
os
,
print_cpp
(
std
::
ostream
&
os
,
...
...
src/include/migraphx/op/argmax.hpp
View file @
b98308b8
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -56,12 +57,20 @@ struct argmax
...
@@ -56,12 +57,20 @@ struct argmax
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
lens
=
inputs
[
0
].
lens
();
const
auto
&
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
())
lens
[
axis
]
=
1
;
{
auto
dyn_dims
=
s0
.
dyn_dims
();
return
{
shape
::
int64_type
,
lens
};
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
else
{
auto
lens
=
s0
.
lens
();
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
}
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -79,19 +88,18 @@ struct argmax
...
@@ -79,19 +88,18 @@ struct argmax
max_index
=
i
;
max_index
=
i
;
}
}
}
}
return
max_index
;
return
max_index
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
out
put_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
out
put_shape
.
multi
(
i
);
auto
data_idx
=
dyn_out
.
com
put
ed
_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
});
});
});
});
...
...
src/include/migraphx/op/flatten.hpp
View file @
b98308b8
...
@@ -55,17 +55,47 @@ struct flatten
...
@@ -55,17 +55,47 @@ struct flatten
std
::
string
name
()
const
{
return
"flatten"
;
}
std
::
string
name
()
const
{
return
"flatten"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
s
=
inputs
[
0
];
auto
x
=
if
(
s
.
dynamic
())
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
{
auto
y
=
auto
min_lens
=
s
.
min_lens
();
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
max_lens
=
s
.
max_lens
();
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
auto
opt_lens
=
s
.
opt_lens
();
// If any of the opt values is 0, output opt will be 0
shape
::
dynamic_dimension
x
=
{
std
::
accumulate
(
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
(),
opt_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{})};
shape
::
dynamic_dimension
y
=
{
std
::
accumulate
(
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
()
+
axis
,
opt_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
};
return
{
s
.
type
(),
{
x
,
y
}};
}
else
{
check_shapes
{
inputs
,
*
this
}.
standard
();
auto
&&
lens
=
s
.
lens
();
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
s
.
type
(),
{
x
,
y
}};
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/pooling.hpp
View file @
b98308b8
...
@@ -31,7 +31,7 @@
...
@@ -31,7 +31,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/
int_divide
.hpp>
#include <migraphx/
dyn_output
.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -49,6 +49,9 @@ struct pooling
...
@@ -49,6 +49,9 @@ struct pooling
bool
ceil_mode
=
false
;
bool
ceil_mode
=
false
;
int
lp_order
=
2
;
int
lp_order
=
2
;
// Global pooling with dynamic shape input
bool
dyn_global
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
...
@@ -57,7 +60,8 @@ struct pooling
...
@@ -57,7 +60,8 @@ struct pooling
f
(
self
.
stride
,
"stride"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
),
f
(
self
.
lengths
,
"lengths"
),
f
(
self
.
ceil_mode
,
"ceil_mode"
),
f
(
self
.
ceil_mode
,
"ceil_mode"
),
f
(
self
.
lp_order
,
"lp_order"
));
f
(
self
.
lp_order
,
"lp_order"
),
f
(
self
.
dyn_global
,
"dyn_global"
));
}
}
std
::
string
name
()
const
{
return
"pooling"
;
}
std
::
string
name
()
const
{
return
"pooling"
;
}
...
@@ -65,51 +69,111 @@ struct pooling
...
@@ -65,51 +69,111 @@ struct pooling
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
((
padding
.
size
()
!=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!=
stride
.
size
())
or
if
((
padding
.
size
()
!=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!=
stride
.
size
())
or
stride
.
size
()
!=
lengths
.
size
())
(
not
dyn_global
and
stride
.
size
()
!=
lengths
.
size
())
)
{
{
MIGRAPHX_THROW
(
"POOLING: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"POOLING: inconsistent attribute sizes"
);
}
}
}
}
size_t
kdims
()
const
{
check_attribute_size
();
return
stride
.
size
();
}
value
attributes
()
const
{
return
{{
"normalize_padding"
,
"padding"
}};
}
value
attributes
()
const
{
return
{{
"normalize_padding"
,
"padding"
}};
}
std
::
vector
<
std
::
size_t
>
calc_spatial_dim_out
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
std
::
size_t
kdims
)
const
{
std
::
vector
<
std
::
size_t
>
output_lens
{};
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
if
(
input_lens
[
i
+
2
]
==
0
)
{
// handle opt = 0
output_lens
.
push_back
(
0
);
}
else
{
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
if
(
padding
.
size
()
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
std
::
size_t
len
=
(
ceil_mode
)
?
dim_size
/
stride
[
i
]
+
static_cast
<
std
::
size_t
>
((
dim_size
%
stride
[
i
]
!=
0
))
// ceil uint divide
:
dim_size
/
stride
[
i
];
// floor divide
output_lens
.
push_back
(
len
+
1
);
}
}
return
output_lens
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_attribute_size
();
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
auto
padding_size
=
padding
.
size
();
auto
input_lens
=
input
.
lens
();
size_t
kdims
=
input
.
ndim
()
-
2
;
size_t
kdims
=
input_lens
.
size
()
-
2
;
if
(
input
.
ndim
()
!=
padding_size
/
2
+
2
and
input
.
ndim
()
!=
padding_size
+
2
)
auto
input_size
=
inputs
[
0
].
lens
().
size
();
auto
padding_size
=
padding
.
size
();
if
(
input_size
!=
padding_size
/
2
+
2
and
input_size
!=
padding_size
+
2
)
{
{
MIGRAPHX_THROW
(
"POOLING: input and attribute size mismatch!"
);
MIGRAPHX_THROW
(
"POOLING: input and attribute size mismatch!"
);
}
}
std
::
vector
<
std
::
size_t
>
output_lens
(
input_lens
.
begin
(),
input_lens
.
begin
()
+
2
);
if
(
input
.
dynamic
())
for
(
size_t
i
=
0
;
i
<
kdims
;
i
++
)
{
{
std
::
ptrdiff_t
dim_size
;
auto
input_dyn_dims
=
input
.
dyn_dims
();
auto
padding_factor
=
2
*
padding
[
i
];
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
(
input_dyn_dims
.
begin
(),
if
(
padding_size
==
2
*
kdims
)
input_dyn_dims
.
begin
()
+
2
);
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
if
(
dyn_global
)
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
{
assert
(
dim_size
>=
0
);
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
std
::
size_t
len
=
(
ceil_mode
)
?
ceil_divide
<
std
::
ptrdiff_t
>
(
dim_size
,
stride
[
i
])
{
:
floor_divide
<
std
::
ptrdiff_t
>
(
dim_size
,
stride
[
i
]);
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
1
,
1
,
1
});
}
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
len
+
1
)));
return
{
input
.
type
(),
output_dyn_dims
};
}
else
{
auto
min_spatial_dims
=
calc_spatial_dim_out
(
input
.
min_lens
(),
kdims
);
auto
max_spatial_dims
=
calc_spatial_dim_out
(
input
.
max_lens
(),
kdims
);
auto
opt_spatial_dims
=
calc_spatial_dim_out
(
input
.
opt_lens
(),
kdims
);
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]});
}
return
{
input
.
type
(),
output_dyn_dims
};
}
}
}
return
inputs
[
0
].
with_lens
(
output_lens
);
else
}
{
auto
input_lens
=
input
.
lens
();
size_t
kdims
()
const
std
::
vector
<
std
::
size_t
>
output_lens
(
input_lens
.
begin
(),
input_lens
.
begin
()
+
2
);
{
// Used for when normalize_compute_shape() is called again at model eval time
check_attribute_size
();
// for an originally dynamic shape. Since kernel shape is not used with dyn_global.
return
stride
.
size
();
if
(
dyn_global
)
{
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
output_lens
.
push_back
(
1
);
}
return
{
input
.
type
(),
output_lens
};
}
else
{
auto
output_spatial_lens
=
calc_spatial_dim_out
(
input_lens
,
kdims
);
output_lens
.
insert
(
output_lens
.
end
(),
output_spatial_lens
.
begin
(),
output_spatial_lens
.
end
());
return
inputs
[
0
].
with_lens
(
output_lens
);
}
}
}
}
struct
lpnorm_pool
struct
lpnorm_pool
...
@@ -158,7 +222,11 @@ struct pooling
...
@@ -158,7 +222,11 @@ struct pooling
};
};
template
<
class
Type
,
class
Out
,
class
In
,
class
Op
>
template
<
class
Type
,
class
Out
,
class
In
,
class
Op
>
void
calc_pooling
(
const
shape
&
output_shape
,
Out
&
output
,
const
In
&
input
,
Op
op
)
const
void
calc_pooling
(
const
shape
&
output_shape
,
Out
&
output
,
const
In
&
input
,
const
std
::
vector
<
std
::
size_t
>&
kernel_dims
,
Op
op
)
const
{
{
auto
in_s
=
input
.
get_shape
();
auto
in_s
=
input
.
get_shape
();
auto
in_lens
=
in_s
.
lens
();
auto
in_lens
=
in_s
.
lens
();
...
@@ -172,7 +240,7 @@ struct pooling
...
@@ -172,7 +240,7 @@ struct pooling
auto
d_2
=
dim
-
2
;
auto
d_2
=
dim
-
2
;
int
start
=
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
[
d_2
]);
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
[
d_2
]);
int
end
=
std
::
min
(
start
+
length
s
[
d_2
],
in_lens
[
dim
]);
int
end
=
std
::
min
(
start
+
kernel_dim
s
[
d_2
],
in_lens
[
dim
]);
start
=
std
::
max
(
start
,
0
);
start
=
std
::
max
(
start
,
0
);
win_start
.
push_back
(
start
);
win_start
.
push_back
(
start
);
win_size
.
push_back
(
end
-
start
);
win_size
.
push_back
(
end
-
start
);
...
@@ -198,21 +266,32 @@ struct pooling
...
@@ -198,21 +266,32 @@ struct pooling
});
});
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
dyn_out
.
computed_shape
};
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
std
::
vector
<
std
::
size_t
>
kernel_dims
;
if
(
dyn_global
)
{
kernel_dims
.
insert
(
kernel_dims
.
end
(),
input_lens
.
begin
()
+
2
,
input_lens
.
end
());
}
else
{
kernel_dims
=
this
->
lengths
;
}
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
using
type
=
typename
decltype
(
output
)
::
value_type
;
switch
(
mode
)
switch
(
mode
)
{
{
case
migraphx
::
op
::
pooling_mode
::
average
:
case
migraphx
::
op
::
pooling_mode
::
average
:
calc_pooling
<
type
>
(
out
put_shape
,
output
,
input
,
avg_pool
{});
calc_pooling
<
type
>
(
dyn_out
.
com
put
ed
_shape
,
output
,
input
,
kernel_dims
,
avg_pool
{});
break
;
break
;
case
migraphx
::
op
::
pooling_mode
::
max
:
case
migraphx
::
op
::
pooling_mode
::
max
:
calc_pooling
<
type
>
(
out
put_shape
,
output
,
input
,
max_pool
{});
calc_pooling
<
type
>
(
dyn_out
.
com
put
ed
_shape
,
output
,
input
,
kernel_dims
,
max_pool
{});
break
;
break
;
case
migraphx
::
op
::
pooling_mode
::
lpnorm
:
case
migraphx
::
op
::
pooling_mode
::
lpnorm
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
lpnorm_pool
{
lp_order
});
calc_pooling
<
type
>
(
dyn_out
.
computed_shape
,
output
,
input
,
kernel_dims
,
lpnorm_pool
{
lp_order
});
break
;
break
;
}
}
});
});
...
...
src/include/migraphx/op/softmax.hpp
View file @
b98308b8
...
@@ -53,15 +53,15 @@ struct softmax
...
@@ -53,15 +53,15 @@ struct softmax
std
::
string
name
()
const
{
return
"softmax"
;
}
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
if
(
inputs
.
at
(
0
).
packed
())
auto
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
()
or
s0
.
packed
())
{
{
return
inputs
.
at
(
0
)
;
return
s0
;
}
}
else
else
{
{
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
s0
.
type
(),
s0
.
lens
()};
return
{
inputs
.
at
(
0
).
type
(),
lens
};
}
}
}
}
...
...
src/include/migraphx/op/squeeze.hpp
View file @
b98308b8
...
@@ -59,9 +59,8 @@ struct squeeze
...
@@ -59,9 +59,8 @@ struct squeeze
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
if
(
input_shape
.
dynamic
())
if
(
input_shape
.
dynamic
())
{
{
shape
::
dynamic_dimension
one_dyn_dim
{
1
,
1
,
0
};
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
dyn_dims
()[
axis
]
!=
one_dyn_dim
;
return
input_shape
.
dyn_dims
()[
axis
]
!=
1
;
}))
}))
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
...
@@ -70,14 +69,10 @@ struct squeeze
...
@@ -70,14 +69,10 @@ struct squeeze
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
if
(
axes
.
empty
())
if
(
axes
.
empty
())
{
{
for
(
auto
i
:
range
(
input_shape
.
ndim
()))
std
::
copy_if
(
input_shape
.
dyn_dims
().
cbegin
(),
{
input_shape
.
dyn_dims
().
cend
(),
auto
dd
=
input_shape
.
dyn_dims
()[
i
];
std
::
back_inserter
(
dyn_dims
),
if
(
dd
!=
one_dyn_dim
)
[
&
](
auto
dd
)
{
return
dd
!=
1
;
});
{
dyn_dims
.
push_back
(
dd
);
}
}
}
}
else
else
{
{
...
...
src/include/migraphx/program.hpp
View file @
b98308b8
...
@@ -115,6 +115,7 @@ struct program
...
@@ -115,6 +115,7 @@ struct program
print_func
)
const
;
print_func
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
...
...
src/include/migraphx/shape.hpp
View file @
b98308b8
...
@@ -101,6 +101,12 @@ struct shape
...
@@ -101,6 +101,12 @@ struct shape
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
// compare to fixed std::size_t dimension
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
};
};
static
const
std
::
vector
<
type_t
>&
types
();
static
const
std
::
vector
<
type_t
>&
types
();
...
...
src/include/migraphx/shape_for_each.hpp
View file @
b98308b8
...
@@ -31,6 +31,9 @@
...
@@ -31,6 +31,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Iterates the given function over the indices from the shape in order.
*/
template
<
class
F
>
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
{
{
...
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
...
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call
(
indices
);
call
(
indices
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/insert_pad.cpp
View file @
b98308b8
...
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
...
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
{
{
return
;
return
;
}
}
auto
kdims
=
input
->
get_shape
().
lens
().
size
()
-
2
;
auto
kdims
=
input
->
get_shape
().
ndim
()
-
2
;
if
(
std
::
equal
(
op
.
padding
.
begin
(),
if
(
std
::
equal
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
()))
op
.
padding
.
end
()))
return
;
return
;
std
::
vector
<
int64_t
>
padding
(
input
->
get_shape
().
lens
().
size
()
*
2
,
0
);
std
::
vector
<
int64_t
>
padding
(
input
->
get_shape
().
ndim
()
*
2
,
0
);
std
::
vector
<
size_t
>
pads_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
pads_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
pads_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
std
::
vector
<
size_t
>
pads_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
op
.
padding
=
std
::
vector
<
size_t
>
(
kdims
*
2
,
0
);
op
.
padding
=
std
::
vector
<
size_t
>
(
kdims
*
2
,
0
);
...
...
src/instruction.cpp
View file @
b98308b8
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
}
}
bool
instruction
::
is_undefined
()
const
{
if
(
op
.
name
()
==
"undefined"
)
{
return
true
;
}
else
if
(
this
->
inputs
().
empty
())
{
return
false
;
}
else
{
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
is_undefined
();
});
}
}
bool
instruction
::
can_eval
()
const
bool
instruction
::
can_eval
()
const
{
{
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
)
...
...
src/module.cpp
View file @
b98308b8
...
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
...
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
auto
v
=
op
.
to_value
();
os
<<
"migraphx.op("
<<
enclose_name
(
op
.
name
());
auto
default_values
=
make_op
(
op
.
name
()).
to_value
();
for
(
auto
&&
x
:
v
)
{
auto
name
=
x
.
get_key
();
if
(
default_values
[
name
]
==
x
)
continue
;
os
<<
", "
<<
name
<<
"="
<<
to_json_string
(
x
.
without_key
());
}
os
<<
")"
;
}
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
...
@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
...
@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os
<<
")"
;
os
<<
")"
;
}
}
static
void
print_py_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx.shape("
<<
s
.
type_string
()
<<
", lens="
<<
to_json_string
(
s
.
lens
());
if
(
not
s
.
standard
())
os
<<
", strides="
<<
to_json_string
(
s
.
strides
());
os
<<
")"
;
}
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
{
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
...
@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...
@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os
<<
"}"
;
os
<<
"}"
;
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
// cppcheck-suppress variableScope
unsigned
long
seed
=
names
.
size
();
auto
last
=
std
::
prev
(
this
->
end
());
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
std
::
vector
<
std
::
string
>
input_vars
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
input_vars
),
[
&
](
auto
input
)
{
return
cpp_var_name
(
ins_names
.
at
(
input
));
});
if
(
ins
!=
last
)
os
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
if
(
ins
->
name
()
==
"@literal"
)
{
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
use_abs
=
false
;
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_literal("
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
os
<<
")"
;
os
<<
")"
<<
std
::
endl
;
seed
++
;
}
else
if
(
ins
->
name
()
==
"@param"
)
{
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
os
<<
mname
<<
".add_parameter("
<<
enclose_name
(
name
)
<<
","
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
")"
<<
std
::
endl
;
}
else
if
(
ins
->
name
()
==
"@return"
)
{
os
<<
mname
<<
".add_return(["
<<
join_strings
(
input_vars
,
", "
)
<<
"])"
<<
std
::
endl
;
}
else
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
os
<<
mname
<<
".add_instruction("
;
print_py_op
(
os
,
ins
->
get_operator
());
os
<<
", ["
<<
join_strings
(
input_vars
,
", "
)
<<
"]"
;
os
<<
")"
<<
std
::
endl
;
}
},
names
);
return
names
;
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_cpp
(
std
::
ostream
&
os
,
module
::
print_cpp
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
const
std
::
string
&
mname
,
...
@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
...
@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
return
names
;
return
names
;
}
}
void
module
::
print_py
(
std
::
ostream
&
os
)
const
{
this
->
print_py
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment