gaoqiong / MIGraphX · Commits

Commit 38163d54, authored Aug 11, 2022 by turneram
Merge remote-tracking branch 'origin/develop' into bert-attention-no-transpose-ops
Parents: 7e316254, 5bf4dee6

Changes: 36 in total; this page shows 20 changed files with 406 additions and 96 deletions (+406 / -96).
Changed files:

README.md (+1, -0)
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp (+8, -6)
src/CMakeLists.txt (+1, -0)
src/auto_contiguous.cpp (+1, -1)
src/dead_code_elimination.cpp (+4, -3)
src/include/migraphx/literal.hpp (+5, -0)
src/include/migraphx/op/common.hpp (+3, -1)
src/include/migraphx/op/convolution.hpp (+122, -26)
src/include/migraphx/op/quant_convolution.hpp (+5, -3)
src/include/migraphx/operation.hpp (+5, -3)
src/include/migraphx/pad_calc.hpp (+23, -25)
src/insert_pad.cpp (+6, -0)
src/instruction.cpp (+2, -2)
src/normalize_ops.cpp (+2, -2)
src/onnx/onnx_parser.cpp (+2, -3)
src/onnx/parse_constant.cpp (+1, -1)
src/onnx/parse_convolution.cpp (+63, -17)
src/pad_calc.cpp (+90, -0, new file)
src/program.cpp (+8, -2)
src/targets/ref/lowering.cpp (+54, -1)
README.md
@@ -46,6 +46,7 @@ The following is a list of prerequisites required to build MIGraphX source.
 * [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings
 * [JSON](https://github.com/nlohmann/json) - for model serialization to json string format
 * [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format
+* [SQLite3](https://www.sqlite.org/index.html) - to create database of kernels' tuning information or execute queries on existing database
 #### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild).
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
@@ -23,7 +23,7 @@
  */
 #include <algorithm>
 #include <hip/hip_runtime.h>
-#include <rocblas.h>
+#include <rocblas/rocblas.h>
 #include <migraphx/migraphx.h>
 #include <migraphx/migraphx.hpp> // MIGraphX's C++ API
 #include <numeric>
@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
                                migraphx::arguments args) const override
     {
         // create rocblas stream handle
-        auto rocblas_handle = create_rocblas_handle_ptr(ctx);
-        rocblas_int n       = args[1].get_shape().lengths()[0];
-        float* alpha        = reinterpret_cast<float*>(args[0].data());
-        float* vec_ptr      = reinterpret_cast<float*>(args[1].data());
-        MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1));
+        auto rb_handle = create_rocblas_handle_ptr(ctx);
+        rocblas_int n  = args[1].get_shape().lengths()[0];
+        MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
+        float* alpha   = reinterpret_cast<float*>(args[0].data());
+        float* vec_ptr = reinterpret_cast<float*>(args[1].data());
+        MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
+        MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
         return args[1];
     }
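The two behavioral fixes in this hunk are easy to miss among the renames: `alpha` is passed as a device pointer, so the handle has to be switched to `rocblas_pointer_mode_device` before `rocblas_sscal` will dereference it in GPU memory, and the handle is now destroyed after use instead of leaking. A minimal standalone sketch of the same pattern, assuming a working HIP/rocBLAS install (the buffer setup is illustrative, not taken from the example above):

    #include <hip/hip_runtime.h>
    #include <rocblas/rocblas.h>
    #include <vector>

    int main()
    {
        std::vector<float> host = {1.0f, 2.0f, 3.0f, 4.0f};
        const float host_alpha  = 2.0f;
        rocblas_int n           = static_cast<rocblas_int>(host.size());

        // device buffers for the vector and for alpha itself
        float* d_vec;
        float* d_alpha;
        hipMalloc(&d_vec, n * sizeof(float));
        hipMalloc(&d_alpha, sizeof(float));
        hipMemcpy(d_vec, host.data(), n * sizeof(float), hipMemcpyHostToDevice);
        hipMemcpy(d_alpha, &host_alpha, sizeof(float), hipMemcpyHostToDevice);

        rocblas_handle handle;
        rocblas_create_handle(&handle);
        // alpha lives in device memory, so rocBLAS must read it from there
        rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device);
        rocblas_sscal(handle, n, d_alpha, d_vec, 1);
        // destroy the handle when done, mirroring the leak fix above
        rocblas_destroy_handle(handle);

        hipMemcpy(host.data(), d_vec, n * sizeof(float), hipMemcpyDeviceToHost);
        hipFree(d_vec);
        hipFree(d_alpha);
        return 0;
    }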
src/CMakeLists.txt
@@ -65,6 +65,7 @@ add_library(migraphx
     operation.cpp
     opt/memory_coloring.cpp
     opt/memory_coloring_impl.cpp
+    pad_calc.cpp
     pass_manager.cpp
     permutation.cpp
     preallocate_param.cpp
src/auto_contiguous.cpp
@@ -63,7 +63,7 @@ void auto_contiguous::apply(module& m) const
         if(ins->outputs().empty() and ins != last)
             continue;
         shape s = ins->get_shape();
-        if(not s.standard() and s.elements() != 0)
+        if(not s.dynamic() and not s.standard() and s.elements() != 0)
         {
             auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
             m.replace_instruction(ins, c);
src/dead_code_elimination.cpp
@@ -48,9 +48,10 @@ void dead_code_elimination::apply(module& m) const
         // Skip the last instruction
         if(i == last)
             break;
-        // Skip instruction with empty shape as output unless its a builtin, undefined, identity, or
-        // allocate
-        if(i->get_shape().elements() == 0 and i->name().front() != '@' and
+        // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
+        // identity, allocate]
+        if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
+           i->name().front() != '@' and
            not contains({"undefined", "identity", "allocate"}, i->name()))
             continue;
         assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
src/include/migraphx/literal.hpp
@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
 {
     literal() {}

+    /*!
+     * Empty literal with a specific shape type
+     */
+    explicit literal(shape::type_t shape_type) : m_shape(shape_type, {}) {}
+
     template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
     literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
     {
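The point of the new constructor shows up in the onnx_parser.cpp and parse_constant.cpp hunks further down: a zero-element literal previously came back default-constructed (`return {};`) and so lost its element type, while `literal{shape_type}` keeps it. A hedged usage sketch, assuming the usual `raw_data::get_shape()` accessor:

    #include <migraphx/literal.hpp>
    #include <cassert>

    void empty_literal_example()
    {
        // An empty literal that still reports float_type rather than a
        // typeless default-constructed shape. Only the type is checked here;
        // what elements() returns for an empty lens vector is a shape-class
        // convention this sketch does not assume.
        migraphx::literal empty{migraphx::shape::float_type};
        assert(empty.get_shape().type() == migraphx::shape::float_type);
    }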
src/include/migraphx/op/common.hpp
@@ -37,7 +37,9 @@ enum padding_mode_t
 {
     default_, // NOLINT
     same,
-    valid
+    valid,
+    same_lower,
+    same_upper
 };
 // The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
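The two new enumerators carry ONNX `auto_pad` semantics through to evaluation time. A sketch of the mapping the parser applies (see the parse_convolution.cpp hunk below); `same_upper` and `same_lower` differ only in which side takes the odd padding unit:

    //   auto_pad == "SAME_UPPER", dynamic shapes -> op::padding_mode_t::same_upper
    //   auto_pad == "SAME_LOWER", dynamic shapes -> op::padding_mode_t::same_lower
    //   auto_pad == "SAME_*",     fixed shapes   -> op::padding_mode_t::same
    //                                               (padding precomputed at parse time)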
src/include/migraphx/op/convolution.hpp
@@ -41,8 +41,9 @@ struct convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
     int group                         = 1;
     padding_mode_t padding_mode       = default_;
+    bool use_dynamic_same_auto_pad    = false;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +52,8 @@ struct convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.group, "group"),
-                    f(self.padding_mode, "padding_mode"));
+                    f(self.padding_mode, "padding_mode"),
+                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
     }

     std::string name() const { return "convolution"; }
@@ -69,43 +71,137 @@ struct convolution
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
+        check_shapes{inputs, *this, true}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
-        // dim num of input and attribute should match
-        auto input_size   = inputs[0].lens().size();
-        auto padding_size = padding.size();
+        // num of dims of input and attribute should match
+        const auto input_size   = inputs[0].max_lens().size();
+        const auto padding_size = padding.size();
         if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
             MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
         }
-        const shape& input   = inputs.at(0);
-        const shape& weights = inputs.at(1);
-        size_t kdims         = input_size - 2;
-        if(kdims != this->kdims())
+        const shape& x_shape          = inputs.at(0);
+        const shape& w_shape          = inputs.at(1);
+        const size_t num_spatial_dims = input_size - 2;
+        if(num_spatial_dims != this->kdims())
         {
-            MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
+            MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
         }
-        if(input.lens().at(1) != (weights.lens().at(1) * group))
-            MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
-        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
-        for(size_t i = 0; i < kdims; i++)
-        {
-            auto padding_factor = 2 * padding[i];
-            if(padding_size == 2 * kdims)
-            {
-                padding_factor = padding[i] + padding[i + kdims];
-            }
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
-                1,
-                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 padding_factor) /
-                        stride[i] +
-                    1)));
-        }
-        return inputs[0].with_lens(output_lens);
+        if(not x_shape.dynamic() and not w_shape.dynamic() and
+           x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
+            MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
+        std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
+                                                         op::padding_mode_t::same_lower};
+        if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
+        {
+            MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
+        }
+        if(x_shape.dynamic() or w_shape.dynamic())
+        {
+            return dynamic_compute_shape(x_shape, w_shape);
+        }
+        else
+        {
+            return fixed_compute_shape(x_shape, w_shape);
+        }
+    }
+
+    std::vector<std::size_t> calc_conv_lens(std::vector<std::size_t> x_lens,
+                                            std::vector<std::size_t> w_lens) const
+    {
+        const size_t num_spatial_dims = x_lens.size() - 2;
+        std::vector<size_t> ret       = {};
+        // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
+        for(size_t i = 0; i < num_spatial_dims; i++)
+        {
+            if(x_lens[i] == 0 or w_lens[i] == 0)
+            {
+                // for handling when a dimension = 0 (opt of dynamic_dimension)
+                ret.push_back(0);
+            }
+            else
+            {
+                auto padding_factor = 2 * padding[i];
+                if(padding.size() == 2 * num_spatial_dims)
+                {
+                    // when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
+                    padding_factor = padding[i] + padding[i + num_spatial_dims];
+                }
+                ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                    1,
+                    (x_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
+                            stride[i] +
+                        1)));
+            }
+        }
+        return ret;
+    }
+
+    shape dynamic_compute_shape(shape x_shape, shape w_shape) const
+    {
+        std::vector<shape::dynamic_dimension> output_dyn_dims = {};
+        auto dynamic_shape_push_back = [&](const shape& input_shape) {
+            if(input_shape.dynamic())
+            {
+                output_dyn_dims.push_back(input_shape.dyn_dims().at(0));
+            }
+            else
+            {
+                auto l = input_shape.lens().at(0);
+                output_dyn_dims.push_back({l, l, 0});
+            }
+        };
+        dynamic_shape_push_back(x_shape);
+        dynamic_shape_push_back(w_shape);
+        const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
+        if(use_dynamic_same_auto_pad)
+        {
+            for(std::size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; };
+                auto s        = stride[i];
+                if(x_shape.dynamic())
+                {
+                    auto x = x_shape.dyn_dims()[i + 2];
+                    output_dyn_dims.push_back(shape::dynamic_dimension{
+                        ceil_div(x.min, s), ceil_div(x.max, s), ceil_div(x.opt, s)});
+                }
+                else
+                {
+                    auto od = ceil_div(x_shape.lens()[i + 2], s);
+                    output_dyn_dims.push_back(shape::dynamic_dimension{od, od, 0});
+                }
+            }
+        }
+        else
+        {
+            auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
+            auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
+            auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
+            for(size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                output_dyn_dims.push_back(shape::dynamic_dimension{
+                    min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
+            }
+        }
+        return shape{x_shape.type(), output_dyn_dims};
+    }
+
+    shape fixed_compute_shape(shape x_shape, shape w_shape) const
+    {
+        std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
+        auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens());
+        std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
+            output_lens.push_back(x);
+        });
+        return x_shape.with_lens(output_lens);
     }

     size_t kdims() const
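`calc_conv_lens` above is the per-axis formula out = max(1, (W - K_eff + P_total) / S + 1), with the dilation folded into the effective kernel extent K_eff = 1 + d * (K - 1); the dynamic path then derives its lower bound by pairing minimum image lengths with maximum kernel lengths, and vice versa for the upper bound. A standalone sketch of that arithmetic (names are illustrative, not from the header):

    #include <algorithm>
    #include <cstddef>

    // out = max(1, (W - (1 + d*(K - 1)) + pad_begin + pad_end) / S + 1)
    std::size_t conv_out_len(std::size_t w, std::size_t k, std::size_t s,
                             std::size_t d, std::size_t pad_begin, std::size_t pad_end)
    {
        std::ptrdiff_t eff_k = 1 + std::ptrdiff_t(d) * (std::ptrdiff_t(k) - 1);
        std::ptrdiff_t num =
            std::ptrdiff_t(w) - eff_k + std::ptrdiff_t(pad_begin + pad_end);
        return std::size_t(std::max<std::ptrdiff_t>(1, num / std::ptrdiff_t(s) + 1));
    }

    // Worked example: W = 7, K = 3, S = 2, d = 2, padding 1 on each side:
    // dilated kernel extent 5, so (7 - 5 + 2) / 2 + 1 = 3
    // conv_out_len(7, 3, 2, 2, 1, 1) == 3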
src/include/migraphx/op/quant_convolution.hpp
@@ -41,8 +41,9 @@ struct quant_convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
     padding_mode_t padding_mode       = default_;
     int group                         = 1;
+    bool use_dynamic_same_auto_pad    = false;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +52,8 @@ struct quant_convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.padding_mode, "padding_mode"),
-                    f(self.group, "group"));
+                    f(self.group, "group"),
+                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
     }

     value attributes() const
src/include/migraphx/operation.hpp
@@ -68,8 +68,10 @@ struct operation
  *
  * @param ctx This is the context created by the `target` during compilation. Implementations
  * can use the target's `context` class rather than the `context` interface class.
- * @param output This is the output shape. It is equivalent to running `compute_shape` with each
- * `shape` of the `argument`.
+ * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
+ * For a fixed shape, the returned argument will have the same shape as `output`.
+ * For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
+ * set in the dynamic shape `output`.
  * @param input This is the `argument` result from the previous instruction's computation.
  * @return Return an `argument` of the result computation. The `shape` of `argument` should be
  * the same the `output` shape.
@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
     -> decltype(x.normalize_compute_shape(inputs))
 {
     dependent_type<operation, T> y = x;
-    normalize_attributes(y, inputs[0].lens());
+    normalize_attributes(y, inputs[0].max_lens());
     return any_cast<T>(y).normalize_compute_shape(inputs);
 }
src/include/migraphx/pad_calc.hpp
@@ -24,38 +24,36 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP

-#include <utility>
+#include <migraphx/config.hpp>
 #include <cstdint>
 #include <vector>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-inline void calculate_padding(int64_t idx,
-                              std::vector<int64_t>& pads,
-                              int64_t input_dim,
-                              int64_t stride,
-                              int64_t dilation,
-                              int64_t weight_dim,
-                              bool is_same_upper = true)
-{
-    int64_t output_dim     = (input_dim + stride - 1) / stride; // round up result
-    int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
-    int64_t pad =
-        std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
-    auto pad_ndims = pads.size() / 2;
-
-    if(is_same_upper)
-    {
-        pads[idx]             = pad / 2;
-        pads[idx + pad_ndims] = pad - pad / 2;
-    }
-    else
-    {
-        pads[idx + pad_ndims] = pad / 2;
-        pads[idx]             = pad - pad / 2;
-    }
-}
+void calculate_padding(int64_t idx,
+                       std::vector<int64_t>& pads,
+                       int64_t input_dim,
+                       int64_t stride,
+                       int64_t dilation,
+                       int64_t weight_dim,
+                       bool is_same_upper = true);
+
+/*!
+ * Calculate the padding for auto_padding. Used for dynamic shapes
+ * where the padding calculation must be done at evaluation time.
+ * \param tensor_lens input tensor image shape
+ * \param k_lens weights kernel shape
+ * \param strides strides for the kernel
+ * \param dilations dilations for the kernel
+ * \param use_upper put odd padding on upper or lower side
+ * \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
+ */
+std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
+                                           std::vector<std::size_t> k_lens,
+                                           std::vector<std::size_t> strides,
+                                           std::vector<std::size_t> dilations,
+                                           bool use_upper = true);

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
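For reference, the SAME-padding arithmetic that `calc_dyn_auto_pad` is declared to perform (and that `calculate_padding` already did per index) reduces per axis to: output = ceil(W / S), total pad = max(0, (output - 1) * S + K_eff - W), split across begin/end with the odd unit placed according to `use_upper`. A single-axis sketch under those assumptions (the helper name is illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <utility>

    // Returns {pad_begin, pad_end} for one spatial axis under SAME padding.
    std::pair<int64_t, int64_t> same_pad(int64_t w, int64_t k, int64_t s,
                                         int64_t d, bool upper)
    {
        int64_t out   = (w + s - 1) / s;            // ceil(W / S)
        int64_t eff_k = k + (k - 1) * (d - 1);      // dilated kernel extent
        int64_t pad   = std::max<int64_t>(0, (out - 1) * s + eff_k - w);
        int64_t lo    = upper ? pad / 2 : pad - pad / 2;
        return {lo, pad - lo};
        // e.g. same_pad(10, 3, 3, 1, true) == {1, 1}:
        // out = 4, total pad = (4 - 1) * 3 + 3 - 10 = 2, split evenly
    }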
src/insert_pad.cpp
@@ -40,6 +40,12 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
     auto val        = op.to_value();
     auto op_padding = val.at("padding").to_vector<size_t>();

+    // skip if shape is dynamic
+    if(input->get_shape().dynamic())
+    {
+        return;
+    }
+
     auto kdims = input->get_shape().lens().size() - 2;
     if(std::equal(op_padding.begin(),
                   op_padding.begin() + kdims,
src/instruction.cpp
@@ -445,8 +445,8 @@ operation instruction::normalized_operator() const
     operation o = this->get_operator();
     if(this->need_normalization())
     {
-        auto lens = this->inputs().front()->get_shape().lens();
-        if(!normalize_attributes(o, lens))
+        auto s = this->inputs().front()->get_shape();
+        if(!normalize_attributes(o, s.max_lens()))
             return this->get_operator();
     }
     return o;
src/normalize_ops.cpp
@@ -43,9 +43,9 @@ void normalize_ops::apply(module& m) const
         if(inputs.empty())
             continue;

-        auto lens = inputs[0]->get_shape().lens();
+        auto s = inputs[0]->get_shape();
         migraphx::operation tuned_op = ins->get_operator();
-        if(normalize_attributes(tuned_op, lens))
+        if(normalize_attributes(tuned_op, s.max_lens()))
         {
             m.replace_instruction(ins, tuned_op, inputs);
             ins->set_normalized();
src/onnx/onnx_parser.cpp
@@ -28,7 +28,6 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/pad_calc.hpp>
 #include <migraphx/common.hpp>
 #include <migraphx/type_traits.hpp>
 #include <migraphx/float_equal.hpp>
@@ -60,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // in case of scalar constants in onnx file, use dims=1 to fill initializer data
@@ -77,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // scalar input
src/onnx/parse_constant.cpp
@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
         // return empty literal
         if(v.get_shape().elements() == 0)
         {
-            return info.add_literal(literal{});
+            return info.add_literal(literal{v.get_shape().type()});
         }

         auto dim_size = info.attributes.at("value").t().dims_size();
src/onnx/parse_convolution.cpp
@@ -47,15 +47,17 @@ struct parse_convolution : op_parser<parse_convolution>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
         auto op      = make_op(opd.op_name);
         auto values  = op.to_value();
         auto l0      = args[0];
         auto weights = args[1];
-        auto in_lens = l0->get_shape().lens();
+        auto l0_shape = l0->get_shape();
+        auto w_shape  = weights->get_shape();
+        auto in_lens  = l0_shape.max_lens();
         assert(in_lens.size() > 2);
         auto kdims = in_lens.size() - 2;

-        // ensure pads availabe only when auto_pad is "NOT_SET"
+        // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "CONV");

         if(contains(info.attributes, "strides"))
@@ -79,21 +81,65 @@ struct parse_convolution : op_parser<parse_convolution>
             copy(info.attributes["pads"].ints(), std::back_inserter(padding));
             check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
         }

         if(contains(info.attributes, "auto_pad"))
         {
-            auto weight_lens = weights->get_shape().lens();
-            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
-            auto auto_pad = info.attributes["auto_pad"].s();
-            cal_auto_padding_size(info,
-                                  values,
-                                  k_lens,
-                                  values["dilation"].to_vector<std::size_t>(),
-                                  in_lens,
-                                  padding);
+            bool is_same_padding = false;
+            auto auto_pad        = info.attributes["auto_pad"].s();
             if(auto_pad.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                is_same_padding = true;
             }
+
+            // check if image shape is dynamic
+            bool image_shape_dynamic = false;
+            if(l0_shape.dynamic())
+            {
+                auto dyn_dims = l0_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        image_shape_dynamic = true;
+                    }
+                });
+            }
+            // check if kernel shape is dynamic
+            bool kernel_shape_dynamic = false;
+            if(w_shape.dynamic())
+            {
+                auto dyn_dims = w_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        kernel_shape_dynamic = true;
+                    }
+                });
+            }
+
+            if(is_same_padding)
+            {
+                if(image_shape_dynamic or kernel_shape_dynamic)
+                {
+                    // must calculate "same" padding with input shape data
+                    bool is_same_upper     = (auto_pad.find("SAME_UPPER") != std::string::npos);
+                    values["padding_mode"] = is_same_upper
+                                                 ? to_value(op::padding_mode_t::same_upper)
+                                                 : to_value(op::padding_mode_t::same_lower);
+                    values["use_dynamic_same_auto_pad"] = true;
+                }
+                else
+                {
+                    values["padding_mode"] = to_value(op::padding_mode_t::same);
+                    // kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
+                    auto weight_lens = weights->get_shape().max_lens();
+                    std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
+                    cal_auto_padding_size(info,
+                                          values,
+                                          k_lens,
+                                          values["dilation"].to_vector<std::size_t>(),
+                                          in_lens,
+                                          padding);
+                }
+            }
         }

         values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
src/pad_calc.cpp (new file, mode 100644)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void calculate_padding(int64_t idx,
                       std::vector<int64_t>& pads,
                       int64_t input_dim,
                       int64_t stride,
                       int64_t dilation,
                       int64_t weight_dim,
                       bool is_same_upper)
{
    int64_t output_dim     = (input_dim + stride - 1) / stride; // round up result
    int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
    int64_t pad =
        std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
    auto pad_ndims = pads.size() / 2;

    if(is_same_upper)
    {
        pads[idx]             = pad / 2;
        pads[idx + pad_ndims] = pad - pad / 2;
    }
    else
    {
        pads[idx + pad_ndims] = pad / 2;
        pads[idx]             = pad - pad / 2;
    }
}

std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
                                           std::vector<std::size_t> k_lens,
                                           std::vector<std::size_t> strides,
                                           std::vector<std::size_t> dilations,
                                           bool use_upper)
{
    std::vector<std::size_t> padding;
    padding.resize(2 * k_lens.size());
    for(size_t i = 0; i < padding.size() / 2; i++)
    {
        std::ptrdiff_t input_dim      = tensor_lens[i];
        std::ptrdiff_t stride         = strides[i];
        std::ptrdiff_t weight_dim     = k_lens[i];
        std::ptrdiff_t dilation       = dilations[i];
        std::ptrdiff_t output_dim     = (input_dim + stride - 1) / stride; // round up result
        std::ptrdiff_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
        std::size_t pad               = std::max(static_cast<std::ptrdiff_t>(0),
                                   (output_dim - 1) * stride + new_weight_dim - input_dim);
        auto pad_ndims                = padding.size() / 2;

        if(use_upper)
        {
            padding[i]             = pad / 2;
            padding[i + pad_ndims] = pad - pad / 2;
        }
        else
        {
            padding[i + pad_ndims] = pad / 2;
            padding[i]             = pad - pad / 2;
        }
    }
    return padding;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
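A hedged usage sketch of the new function, matching how the ref lowering below calls it at evaluation time (values worked by hand from the arithmetic above):

    #include <migraphx/pad_calc.hpp>
    #include <cstddef>
    #include <vector>

    void calc_dyn_auto_pad_example()
    {
        // 10x10 spatial dims, 3x3 kernel, stride 3, dilation 1:
        // per axis: output = ceil(10/3) = 4, pad = (4-1)*3 + 3 - 10 = 2 -> {1, 1}
        std::vector<std::size_t> img{10, 10};
        std::vector<std::size_t> krn{3, 3};
        auto pads = migraphx::calc_dyn_auto_pad(img, krn, {3, 3}, {1, 1});
        // pads == {1, 1, 1, 1}, laid out as {x0_begin, x1_begin, x0_end, x1_end}
    }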
src/program.cpp
@@ -307,9 +307,12 @@ std::vector<argument> generic_eval(const module* mod,
                     if(not contains(params, param_name))
                         MIGRAPHX_THROW("Parameter not found: " + param_name);
                     auto param = params[param_name];
-                    if(param.get_shape() != ins->get_shape())
+                    // TODO: may want to check correct number of dimensions and/or was within bounds
+                    if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
+                    {
                         MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
                                        "} for parameter: " + param_name);
+                    }
                     return param;
                 }));
             }
@@ -352,7 +355,10 @@ std::vector<argument> generic_eval(const module* mod,
             }));
         }
         assert(results.find(ins) != results.end());
-        assert(results.at(ins).get_shape() == ins->get_shape());
+        if(not ins->get_shape().dynamic())
+        {
+            assert(results.at(ins).get_shape() == ins->get_shape());
+        }
     }
     return {results.at(std::prev(mod->end()))};
 }
src/targets/ref/lowering.cpp
@@ -51,6 +51,8 @@
 #include <migraphx/register_op.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/tune_axis.hpp>
+#include <migraphx/pad_calc.hpp>
 #include <unordered_map>
 #include <utility>
 #include <iostream>
@@ -231,8 +233,31 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
     {
         return op.normalize_compute_shape(inputs);
     }

     argument compute(context&, shape output_shape, std::vector<argument> args) const
     {
+        std::vector<std::size_t> padding;
+        if(op.use_dynamic_same_auto_pad)
+        {
+            auto input_lens = args[0].get_shape().lens();
+            std::vector<std::size_t> img_lens{input_lens.begin() + 2, input_lens.end()};
+            auto weights_lens = args[1].get_shape().lens();
+            std::vector<std::size_t> k_lens{weights_lens.begin() + 2, weights_lens.end()};
+            padding = calc_dyn_auto_pad(img_lens, k_lens, op.stride, op.dilation);
+            std::cout << "[ ";
+            output_shape =
+                compute_padded_shape({args.at(0).get_shape(), args.at(1).get_shape()}, padding);
+        }
+        else
+        {
+            padding = op.padding;
+            if(output_shape.dynamic())
+            {
+                output_shape =
+                    op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
+            }
+        }
         argument result{output_shape};
         visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
             auto in_lens = input.get_shape().lens();
@@ -252,7 +277,7 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
                 {
                     auto d_2 = dim - 2;
                     win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
-                                        std::ptrdiff_t(op.padding[d_2]));
+                                        std::ptrdiff_t(padding[d_2]));
                 }
                 const auto group_id = w / (wei_n / op.group);
@@ -289,6 +314,34 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
         });
         return result;
     }
+
+    private:
+    /*!
+     * Used for dynamic auto padding since padding needs to be computed at evaluation time.
+     * \param inputs two fixed shape inputs [input_tensor, weights]
+     * \param padding from auto_pad calculation
+     */
+    shape compute_padded_shape(const std::vector<shape>& inputs,
+                               const std::vector<std::size_t>& padding) const
+    {
+        const shape& input            = inputs.at(0);
+        const shape& weights          = inputs.at(1);
+        const size_t num_spatial_dims = input.lens().size() - 2;
+        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+        // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
+        for(size_t i = 0; i < num_spatial_dims; i++)
+        {
+            auto padding_factor = padding[i] + padding[i + num_spatial_dims];
+            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                1,
+                (input.lens()[i + 2] - (1 + op.dilation[i] * (weights.lens()[i + 2] - 1)) +
+                 padding_factor) /
+                        op.stride[i] +
+                    1)));
+        }
+        return inputs[0].with_lens(output_lens);
+    }
 };

 struct ref_im2col