Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dae94657
Unverified
Commit
dae94657
authored
Dec 14, 2022
by
Chris Austen
Committed by
GitHub
Dec 14, 2022
Browse files
Merge branch 'develop' into jit-reduce-reg
parents
b013d991
56c43445
Changes
201
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
275 additions
and
137 deletions
+275
-137
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+17
-4
src/file_buffer.cpp
src/file_buffer.cpp
+16
-8
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+13
-2
src/include/migraphx/argument.hpp
src/include/migraphx/argument.hpp
+1
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+15
-1
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+3
-0
src/include/migraphx/dyn_output.hpp
src/include/migraphx/dyn_output.hpp
+33
-19
src/include/migraphx/file_buffer.hpp
src/include/migraphx/file_buffer.hpp
+1
-1
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+2
-0
src/include/migraphx/layout_nhwc.hpp
src/include/migraphx/layout_nhwc.hpp
+8
-9
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+6
-15
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+6
-0
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+19
-11
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+14
-4
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+90
-33
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+2
-2
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+18
-9
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+1
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+8
-16
src/include/migraphx/op/deconvolution.hpp
src/include/migraphx/op/deconvolution.hpp
+2
-2
No files found.
src/eliminate_contiguous.cpp
View file @
dae94657
...
@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
try
try
{
{
shape
new_shape
=
ins
->
get_operator
().
compute_shape
(
inputs
,
mods
);
shape
new_shape
=
ins
->
get_operator
().
compute_shape
(
inputs
,
mods
);
// Cannot tell if a dynamic shape will need to be made contiguous
if
(
new_shape
.
dynamic
())
{
return
false
;
}
// If the output shape is a standard shape, no need to try its output
// If the output shape is a standard shape, no need to try its output
if
(
new_shape
.
standard
())
if
(
new_shape
.
standard
())
{
{
...
@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
...
@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
}
}
}
}
// Perform evaluations in parallel
// Perform
static contiguous
evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instructions
.
size
());
std
::
vector
<
argument
>
literals
(
const_instructions
.
size
());
par_for
(
const_instructions
.
size
(),
1
,
[
&
](
const
auto
i
)
{
par_for
(
const_instructions
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instructions
[
i
]
->
inputs
().
front
();
auto
prev
=
const_instructions
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
// compute the output contiguous shape from the previous instruction shape
shape
computed_shape
=
c
.
compute_shape
({
prev
->
get_shape
()});
const
std
::
vector
<
argument
>&
prev_eval
=
{
prev
->
eval
()};
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto
co_shape
=
make_compute_output_shape
(
pack
(
c
,
computed_shape
,
prev_eval
));
literals
[
i
]
=
c
.
compute
(
co_shape
,
prev_eval
);
});
});
// Replace static contiguous operations with a literal
for
(
size_t
i
=
0
;
i
<
const_instructions
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
const_instructions
.
size
();
i
++
)
{
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
...
...
src/file_buffer.cpp
View file @
dae94657
...
@@ -30,23 +30,31 @@ namespace migraphx {
...
@@ -30,23 +30,31 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
template
<
class
T
>
T
generic_read_file
(
const
std
::
string
&
filename
)
T
generic_read_file
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
)
{
{
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
streamsize
size
=
is
.
tellg
();
if
(
nbytes
==
0
)
if
(
size
<
1
)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes
=
is
.
tellg
();
if
(
offset
>
nbytes
)
MIGRAPHX_THROW
(
"offset is larger than file size"
);
nbytes
-=
offset
;
}
if
(
nbytes
<
1
)
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
offset
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
T
buffer
(
nbytes
,
0
);
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
nbytes
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
return
buffer
;
}
}
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
)
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
,
size_t
nbytes
)
{
{
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
);
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
,
offset
,
nbytes
);
}
}
std
::
string
read_string
(
const
std
::
string
&
filename
)
std
::
string
read_string
(
const
std
::
string
&
filename
)
...
...
src/fuse_pointwise.cpp
View file @
dae94657
...
@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
...
@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
not
(
s
.
elements
()
=
=
1
or
s
.
scalar
()))
if
(
s
.
elements
()
!
=
1
&&
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
auto
e
=
ins
->
eval
();
auto
e
=
ins
->
eval
();
literal
r
{};
literal
r
{};
// needed for bool as visit_at invokes as() which promotes bool to int8
// Without this we'll break type checks for logical ops that are fused.
if
(
e
.
get_shape
().
type
()
==
shape
::
bool_type
)
{
r
=
literal
{
e
.
at
<
bool
>
()};
}
else
{
e
.
visit_at
([
&
](
auto
x
)
{
r
=
literal
{
x
};
});
e
.
visit_at
([
&
](
auto
x
)
{
r
=
literal
{
x
};
});
}
return
r
;
return
r
;
}
}
...
@@ -56,6 +65,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
...
@@ -56,6 +65,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{
{
if
(
not
ins
->
get_operator
().
attributes
().
get
(
"pointwise"
,
false
))
if
(
not
ins
->
get_operator
().
attributes
().
get
(
"pointwise"
,
false
))
continue
;
continue
;
if
(
ins
->
get_operator
().
name
()
==
"layout"
)
continue
;
assert
(
ins
->
get_operator
().
attributes
().
contains
(
"point_op"
));
assert
(
ins
->
get_operator
().
attributes
().
contains
(
"point_op"
));
auto
*
pm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":pointwise"
+
std
::
to_string
(
n
++
));
auto
*
pm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":pointwise"
+
std
::
to_string
(
n
++
));
pm
->
set_bypass
();
pm
->
set_bypass
();
...
...
src/include/migraphx/argument.hpp
View file @
dae94657
...
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
...
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t
m_data
{};
data_t
m_data
{};
};
};
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
argument
>&
args
);
void
migraphx_to_value
(
value
&
v
,
const
argument
&
a
);
void
migraphx_to_value
(
value
&
v
,
const
argument
&
a
);
void
migraphx_from_value
(
const
value
&
v
,
argument
&
a
);
void
migraphx_from_value
(
const
value
&
v
,
argument
&
a
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
dae94657
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
...
@@ -197,7 +198,7 @@ struct check_shapes
...
@@ -197,7 +198,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_ndims
()
const
const
check_shapes
&
same_ndims
()
const
{
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
ndim
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -232,6 +233,19 @@ struct check_shapes
...
@@ -232,6 +233,19 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are packed with certain layouts
*/
const
check_shapes
&
packed_layouts
(
const
std
::
initializer_list
<
std
::
vector
<
int64_t
>>&
layouts
)
const
{
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
packed
()
and
contains
(
layouts
,
find_permutation
(
s
));
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed with correct layout"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are packed or broadcasted.
* Check all shapes are packed or broadcasted.
*/
*/
...
...
src/include/migraphx/common.hpp
View file @
dae94657
...
@@ -36,6 +36,9 @@ struct operation;
...
@@ -36,6 +36,9 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/
targets/gpu/
include/migraphx/
gpu/batch_norm_inference
.hpp
→
src/include/migraphx/
dyn_output
.hpp
View file @
dae94657
...
@@ -21,41 +21,55 @@
...
@@ -21,41 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_BATCHNORM
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHLIB_DYN_OUTPUT
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_BATCHNORM
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHLIB_DYN_OUTPUT
_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/reflect.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
context
;
struct
dyn_output
{
// original shape from the instruction
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
struct
miopen_batch_norm_inference
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template
<
class
F
>
struct
compute_output_shape
{
{
op
::
batch_norm_inference
op
;
F
ins_inputs
;
template
<
class
Self
,
class
F
>
operator
dyn_output
()
const
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
if
(
ins_shape
.
dynamic
())
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
return
dyn_output
{
ins_shape
,
ins_shape
};
});
}
}
std
::
string
name
()
const
{
return
"gpu::batch_norm_inference"
;
}
operator
shape
()
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
return
shapes
.
size
()
-
1
;
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
}
}
};
};
}
// namespace gpu
template
<
class
F
>
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
{
return
{
f
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
src/include/migraphx/file_buffer.hpp
View file @
dae94657
...
@@ -31,7 +31,7 @@
...
@@ -31,7 +31,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
);
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/instruction.hpp
View file @
dae94657
...
@@ -121,6 +121,8 @@ struct instruction
...
@@ -121,6 +121,8 @@ struct instruction
bool
can_eval
()
const
;
bool
can_eval
()
const
;
bool
is_undefined
()
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
void
finalize
(
context
&
ctx
);
void
finalize
(
context
&
ctx
);
...
...
src/include/migraphx/
rewrite_batchnorm
.hpp
→
src/include/migraphx/
layout_nhwc
.hpp
View file @
dae94657
...
@@ -21,8 +21,8 @@
...
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_FWD_CONV_BATCHNORM_REWRITE
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHX_LAYOUT_NHWC
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_FWD_CONV_BATCHNORM_REWRITE
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHX_LAYOUT_NHWC
_HPP
#include <string>
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
...
@@ -31,18 +31,17 @@
...
@@ -31,18 +31,17 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
_pass_manager
;
/**
/**
*
Rewrite batchnorm to a multiply and add.
*
Transform convolutions to nhwc
*/
*/
struct
rewrite_batchnorm
struct
layout_nhwc
{
{
std
::
string
name
()
const
{
return
"
rewrite_batchnorm
"
;
}
std
::
string
name
()
const
{
return
"
layout_nhwc
"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
_pass_manager
&
mp
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#endif
src/include/migraphx/literal.hpp
View file @
dae94657
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill
(
start
,
end
);
fill
(
start
,
end
);
}
}
// Directly copies buffer of x
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
{
{
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std
::
shared_ptr
<
char
>
buffer
;
std
::
shared_ptr
<
char
>
buffer
;
shape
m_shape
;
shape
m_shape
;
// Keeps the same data ordering as the given container
template
<
class
Iterator
>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
}
else
{
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
std
::
copy
(
start
,
end
,
output
.
begin
());
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
});
});
});
}
}
}
};
};
...
...
src/include/migraphx/module.hpp
View file @
dae94657
...
@@ -205,6 +205,12 @@ struct module
...
@@ -205,6 +205,12 @@ struct module
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_cpp
(
std
::
ostream
&
os
,
print_cpp
(
std
::
ostream
&
os
,
...
...
src/include/migraphx/op/argmax.hpp
View file @
dae94657
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -56,13 +57,21 @@ struct argmax
...
@@ -56,13 +57,21 @@ struct argmax
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
lens
=
inputs
[
0
].
lens
();
const
auto
&
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
())
{
auto
dyn_dims
=
s0
.
dyn_dims
();
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
else
{
auto
lens
=
s0
.
lens
();
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
return
{
shape
::
int64_type
,
lens
};
return
{
shape
::
int64_type
,
lens
};
}
}
}
template
<
class
T
>
template
<
class
T
>
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
int64_t
calc_argmax
(
T
&
input
,
std
::
vector
<
std
::
size_t
>&
indices
,
size_t
item_num
)
const
...
@@ -79,19 +88,18 @@ struct argmax
...
@@ -79,19 +88,18 @@ struct argmax
max_index
=
i
;
max_index
=
i
;
}
}
}
}
return
max_index
;
return
max_index
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
auto
batch_item_num
=
args
.
front
().
get_shape
().
lens
()[
axis
];
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
par_for
(
out
put_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
data_idx
=
out
put_shape
.
multi
(
i
);
auto
data_idx
=
dyn_out
.
com
put
ed
_shape
.
multi
(
i
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
output
[
i
]
=
this
->
calc_argmax
(
input
,
data_idx
,
batch_item_num
);
});
});
});
});
...
...
src/include/migraphx/op/binary.hpp
View file @
dae94657
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
...
@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
value
attributes
()
const
{
return
base_attributes
();
}
value
attributes
()
const
{
return
base_attributes
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
),
true
}
.
has
(
2
)
.
same_type
()
.
same_dims
();
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
==
s1
and
s0
.
packed
())
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
if
(
s0
==
s1
)
return
s0
;
MIGRAPHX_THROW
(
"BINARY: "
+
point_function
()
+
": fixed-dyn shape for inputs"
);
}
else
if
(
s0
==
s1
and
s0
.
packed
())
{
{
return
s0
;
return
s0
;
}
}
...
@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
...
@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
}
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
std
::
transform
(
input1
.
begin
(),
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input1
.
end
(),
...
...
src/include/migraphx/op/broadcast.hpp
View file @
dae94657
...
@@ -27,23 +27,30 @@
...
@@ -27,23 +27,30 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/**
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
* 1 input version:
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
/// axis to zero.
* with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D static shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct
broadcast
struct
broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
=
{}
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -54,36 +61,86 @@ struct broadcast
...
@@ -54,36 +61,86 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
auto
input
=
inputs
.
at
(
0
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
input
.
type
();
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
if
(
inputs
.
size
()
==
1
)
// the broacast op is deprecated now, so not handling the negative
{
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
MIGRAPHX_THROW
(
"BROADCAST : axis "
+
migraphx
::
to_string
(
axis
)
+
" is out of range"
);
}
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
if
(
broadcast_lens
.
size
()
-
axis
<
input
.
lens
().
size
())
{
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than
input
ndims"
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than
s0
ndims"
);
}
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
}
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
input
.
elements
())
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to input size"
);
{
// don't think this can occur?
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
}
return
output
;
}
else
{
// two inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape"
);
}
if
(
s0
.
ndim
()
!=
1
)
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 has ndim "
+
migraphx
::
to_string
(
s0
.
ndim
())
+
", only handle ndim = 1"
);
}
if
(
axis
>=
s1
.
ndim
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: axis "
+
migraphx
::
to_string
(
axis
)
+
" is out of range"
);
}
if
(
s1
.
dynamic
())
{
s0
=
s0
.
to_dynamic
();
if
(
s0
.
dyn_dims
()[
0
]
!=
s1
.
dyn_dims
()[
axis
])
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length ("
+
migraphx
::
to_string
(
s0
.
dyn_dims
()[
0
])
+
" != "
+
migraphx
::
to_string
(
s1
.
dyn_dims
()[
axis
])
+
")"
);
}
return
s1
;
}
if
(
s0
.
lens
()[
0
]
!=
s1
.
lens
()[
axis
])
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length ("
+
migraphx
::
to_string
(
s0
.
lens
()[
0
])
+
" != "
+
migraphx
::
to_string
(
s1
.
lens
()[
axis
])
+
")"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
s1
.
ndim
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
s1
.
lens
(),
std
::
move
(
bcast_strides
)};
return
output
;
return
output
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/common.hpp
View file @
dae94657
...
@@ -33,11 +33,11 @@ namespace migraphx {
...
@@ -33,11 +33,11 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
// Padding mode is default_ for fixed shape padding.
// same_lower and same_upper used for dynamic padding.
enum
padding_mode_t
enum
padding_mode_t
{
{
default_
,
// NOLINT
default_
,
// NOLINT
same
,
valid
,
same_lower
,
same_lower
,
same_upper
same_upper
};
};
...
...
src/include/migraphx/op/contiguous.hpp
View file @
dae94657
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -42,19 +43,27 @@ namespace op {
...
@@ -42,19 +43,27 @@ namespace op {
struct
contiguous
struct
contiguous
{
{
std
::
string
name
()
const
{
return
"contiguous"
;
}
std
::
string
name
()
const
{
return
"contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
if
(
inputs
.
front
().
standard
())
auto
s0
=
inputs
.
front
();
return
inputs
.
front
();
if
(
s0
.
dynamic
()
or
s0
.
standard
())
auto
lens
=
inputs
.
at
(
0
).
lens
();
{
auto
t
=
inputs
.
at
(
0
).
type
();
return
s0
;
}
else
{
const
auto
&
lens
=
s0
.
lens
();
auto
t
=
s0
.
type
();
return
{
t
,
lens
};
return
{
t
,
lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
assert
(
out
put_shape
.
standard
());
assert
(
dyn_out
.
com
put
ed
_shape
.
standard
());
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
...
...
src/include/migraphx/op/convert.hpp
View file @
dae94657
...
@@ -44,7 +44,7 @@ struct convert : unary<convert>
...
@@ -44,7 +44,7 @@ struct convert : unary<convert>
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
dynamic
())
if
(
input
.
dynamic
())
{
{
...
...
src/include/migraphx/op/convolution.hpp
View file @
dae94657
...
@@ -43,7 +43,6 @@ struct convolution
...
@@ -43,7 +43,6 @@ struct convolution
int
group
=
1
;
int
group
=
1
;
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
bool
use_dynamic_same_auto_pad
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -52,16 +51,15 @@ struct convolution
...
@@ -52,16 +51,15 @@ struct convolution
f
(
self
.
stride
,
"stride"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
group
,
"group"
),
f
(
self
.
group
,
"group"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
padding_mode
,
"padding_mode"
));
f
(
self
.
use_dynamic_same_auto_pad
,
"use_dynamic_same_auto_pad"
));
}
}
std
::
string
name
()
const
{
return
"convolution"
;
}
std
::
string
name
()
const
{
return
"convolution"
;
}
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
(
not
(
(
padding
.
size
()
=
=
stride
.
size
()
or
(
padding
.
size
()
/
2
)
=
=
stride
.
size
())
and
if
((
padding
.
size
()
!
=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!
=
stride
.
size
())
or
stride
.
size
()
=
=
dilation
.
size
())
)
stride
.
size
()
!
=
dilation
.
size
())
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"CONVOLUTION: inconsistent attribute sizes"
);
}
}
...
@@ -76,7 +74,8 @@ struct convolution
...
@@ -76,7 +74,8 @@ struct convolution
// num of dims of input and attribute should match
// num of dims of input and attribute should match
const
auto
input_size
=
inputs
[
0
].
max_lens
().
size
();
const
auto
input_size
=
inputs
[
0
].
max_lens
().
size
();
const
auto
padding_size
=
padding
.
size
();
const
auto
padding_size
=
padding
.
size
();
if
(
not
(
input_size
==
padding_size
/
2
+
2
or
input_size
==
padding_size
+
2
))
if
(
input_size
!=
padding_size
/
2
+
2
&&
input_size
!=
padding_size
+
2
)
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
}
}
...
@@ -93,13 +92,6 @@ struct convolution
...
@@ -93,13 +92,6 @@ struct convolution
x_shape
.
lens
().
at
(
1
)
!=
(
w_shape
.
lens
().
at
(
1
)
*
group
))
x_shape
.
lens
().
at
(
1
)
!=
(
w_shape
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
MIGRAPHX_THROW
(
"CONVOLUTION: mismatched channel numbers"
);
std
::
vector
<
op
::
padding_mode_t
>
dyn_pad_modes
=
{
op
::
padding_mode_t
::
same_upper
,
op
::
padding_mode_t
::
same_lower
};
if
(
use_dynamic_same_auto_pad
and
not
contains
(
dyn_pad_modes
,
padding_mode
))
{
MIGRAPHX_THROW
(
"CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode"
);
}
if
(
x_shape
.
dynamic
()
or
w_shape
.
dynamic
())
if
(
x_shape
.
dynamic
()
or
w_shape
.
dynamic
())
{
{
return
dynamic_compute_shape
(
x_shape
,
w_shape
);
return
dynamic_compute_shape
(
x_shape
,
w_shape
);
...
@@ -161,7 +153,7 @@ struct convolution
...
@@ -161,7 +153,7 @@ struct convolution
dynamic_shape_push_back
(
w_shape
);
dynamic_shape_push_back
(
w_shape
);
const
size_t
num_spatial_dims
=
x_shape
.
max_lens
().
size
()
-
2
;
const
size_t
num_spatial_dims
=
x_shape
.
max_lens
().
size
()
-
2
;
if
(
use_dynamic_same_auto_pad
)
if
(
padding_mode
!=
default_
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
{
...
...
src/include/migraphx/op/deconvolution.hpp
View file @
dae94657
...
@@ -61,8 +61,8 @@ struct deconvolution
...
@@ -61,8 +61,8 @@ struct deconvolution
void
check_attribute_size
()
const
void
check_attribute_size
()
const
{
{
if
(
not
(
(
padding
.
size
()
=
=
stride
.
size
()
or
(
padding
.
size
()
/
2
)
=
=
stride
.
size
())
and
if
((
padding
.
size
()
!
=
stride
.
size
()
and
(
padding
.
size
()
/
2
)
!
=
stride
.
size
())
or
stride
.
size
()
=
=
dilation
.
size
())
)
stride
.
size
()
!
=
dilation
.
size
())
{
{
MIGRAPHX_THROW
(
"deconvolution: inconsistent attribute sizes"
);
MIGRAPHX_THROW
(
"deconvolution: inconsistent attribute sizes"
);
}
}
...
...
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment