Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
250d3c87
Unverified
Commit
250d3c87
authored
Oct 01, 2023
by
Chris Austen
Committed by
GitHub
Oct 01, 2023
Browse files
Merge branch 'develop' into ck-flash-attn
parents
135eb63e
f3939b99
Changes
257
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
452 additions
and
88 deletions
+452
-88
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+3
-0
src/driver/main.cpp
src/driver/main.cpp
+24
-9
src/driver/verify.cpp
src/driver/verify.cpp
+38
-14
src/driver/verify.hpp
src/driver/verify.hpp
+4
-3
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+53
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+30
-4
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+49
-0
src/include/migraphx/convolution.hpp
src/include/migraphx/convolution.hpp
+2
-2
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+1
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+12
-8
src/include/migraphx/op/allocate.hpp
src/include/migraphx/op/allocate.hpp
+38
-5
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+6
-2
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+3
-1
src/include/migraphx/op/convolution_backwards.hpp
src/include/migraphx/op/convolution_backwards.hpp
+2
-2
src/include/migraphx/op/fill.hpp
src/include/migraphx/op/fill.hpp
+70
-0
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+7
-8
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+2
-2
src/include/migraphx/op/nonzero.hpp
src/include/migraphx/op/nonzero.hpp
+4
-4
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+103
-23
No files found.
src/driver/CMakeLists.txt
View file @
250d3c87
...
...
@@ -45,6 +45,9 @@ if(NOT WIN32)
endif
()
rocm_clang_tidy_check
(
driver
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
rocm_install_targets
(
...
...
src/driver/main.cpp
View file @
250d3c87
...
...
@@ -475,13 +475,15 @@ struct compiler
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
std
::
cout
<<
"[WARNING]: MIGraphX program was likely compiled with offload_copy "
"set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
}
else
if
(
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
std
::
cout
<<
"
[WARNING]:
MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
...
...
@@ -534,13 +536,19 @@ struct params : command<params>
struct
verify
:
command
<
verify
>
{
compiler
c
;
double
tolerance
=
80
;
migraphx
::
verify
::
tolerance
tols
;
bool
per_instruction
=
false
;
bool
reduce
=
false
;
void
parse
(
argument_parser
&
ap
)
{
c
.
parse
(
ap
);
ap
(
tolerance
,
{
"--tolerance"
},
ap
.
help
(
"Tolerance for errors"
));
ap
(
tols
.
rms_tol
,
{
"--rms-tol"
},
ap
.
help
(
"Tolerance for the RMS error (Default: 0.001)"
));
ap
(
tols
.
atol
,
{
"--atol"
},
ap
.
help
(
"Tolerance for the elementwise absolute difference (Default: 0.001)"
));
ap
(
tols
.
rtol
,
{
"--rtol"
},
ap
.
help
(
"Tolerance for the elementwise relative difference (Default: 0.001)"
));
ap
(
per_instruction
,
{
"-i"
,
"--per-instruction"
},
ap
.
help
(
"Verify each instruction"
),
...
...
@@ -565,15 +573,15 @@ struct verify : command<verify>
if
(
per_instruction
)
{
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
erance
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
s
);
}
else
if
(
reduce
)
{
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
else
{
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
}
};
...
...
@@ -802,6 +810,13 @@ int main(int argc, const char* argv[])
auto
&&
m
=
get_commands
();
auto
cmd
=
args
.
front
();
if
(
cmd
==
"ort-sha"
)
{
std
::
cout
<<
MIGRAPHX_ORT_SHA1
<<
std
::
endl
;
return
0
;
}
if
(
m
.
count
(
cmd
)
>
0
)
{
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
...
...
src/driver/verify.cpp
View file @
250d3c87
...
...
@@ -30,6 +30,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
driver
{
...
...
@@ -76,15 +77,25 @@ void verify_program(const std::string& name,
compile_options
options
,
precision
quantize
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
auto
x
=
run_ref
(
p
,
inputs
);
auto
y
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
auto
ref_outs
=
run_ref
(
p
,
inputs
);
auto
target_outs
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
std
::
size_t
output_num
=
x
.
size
();
std
::
size_t
output_num
=
ref_outs
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
{
verify_args
(
name
,
x
[
i
],
y
[
i
],
tolerance
);
if
(
ref_outs
[
i
].
get_shape
().
type
()
!=
target_outs
[
i
].
get_shape
().
type
()
or
ref_outs
[
i
].
get_shape
().
lens
()
!=
target_outs
[
i
].
get_shape
().
lens
())
{
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"Shape mismatch {"
<<
ref_outs
[
i
].
get_shape
()
<<
"} != {"
<<
target_outs
[
i
].
get_shape
()
<<
"}"
<<
std
::
endl
;
}
else
{
verify_args
(
name
,
target_outs
[
i
],
verify
::
expected
{
ref_outs
[
i
]},
tols
);
}
}
}
...
...
@@ -92,7 +103,7 @@ void verify_instructions(const program& prog,
const
target
&
t
,
compile_options
options
,
precision
quantize
,
double
tolerance
)
verify
::
tolerance
tols
)
{
const
auto
*
mm_prog
=
prog
.
get_main_module
();
for
(
auto
&&
ins
:
(
*
mm_prog
))
...
...
@@ -123,8 +134,7 @@ void verify_instructions(const program& prog,
{
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tolerance
);
verify_program
(
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tols
);
}
catch
(...)
{
...
...
@@ -140,14 +150,22 @@ void verify_reduced(program p,
compile_options
options
,
precision
quantize
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
try
{
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cout
<<
"FAILED: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
}
void
verify_reduced_program
(
const
program
&
p
,
...
...
@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p,
compile_options
options
,
precision
quantize
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
n
;
i
++
)
{
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
auto
last
=
std
::
prev
(
mm
->
end
(),
i
+
1
);
if
(
contains
({
"@literal"
,
"@param"
},
last
->
name
()))
{
std
::
cout
<<
"Skip: "
<<
i
<<
std
::
endl
;
continue
;
}
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
}
...
...
src/driver/verify.hpp
View file @
250d3c87
...
...
@@ -26,6 +26,7 @@
#include "precision.hpp"
#include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace
migraphx
{
namespace
driver
{
...
...
@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
double
tolerance
=
100
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_instructions
(
const
program
&
prog
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_reduced_program
(
const
program
&
p
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
...
...
src/fuse_pointwise.cpp
View file @
250d3c87
...
...
@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
...
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
}
return
changed
;
}
namespace
{
struct
find_pointwise_reshape_pointwise
{
auto
matcher
()
const
{
auto
reshape
=
match
::
name
(
"reshape"
,
"squeeze"
,
"unsqueeze"
,
"flatten"
)(
match
::
used_once
());
auto
skip_contiguous
=
[](
auto
...
ms
)
{
return
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
)(
match
::
used_once
()))(
ms
...));
};
auto
pointwise
=
match
::
name
(
"pointwise"
)(
match
::
used_once
());
auto
reshape_pointwise
=
reshape
(
skip_contiguous
(
pointwise
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
cd
=
common_dims
::
compute
(
ins
->
get_shape
().
lens
(),
x_ins
->
get_shape
().
lens
());
if
(
cd
.
dims
.
empty
())
return
;
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
std
::
transform
(
x_inputs
.
begin
(),
x_inputs
.
end
(),
x_inputs
.
begin
(),
reshape_input
(
x_ins
));
auto
new_x_ins
=
m
.
insert_instruction
(
x_ins
,
x_ins
->
get_operator
(),
x_inputs
,
x_ins
->
module_inputs
());
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
==
reshape_ins
)
return
new_x_ins
;
return
reshape_input
(
ins
)(
input
);
});
auto
pw
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
ins
->
get_shape
().
lens
()}}),
pw
);
}
};
}
// namespace
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
...
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
match
::
find_matches
(
mpm
.
get_module
(),
find_pointwise_reshape_pointwise
{});
mpm
.
run_pass
(
simplify_reshapes
{
1
});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/include/migraphx/check_shapes.hpp
View file @
250d3c87
...
...
@@ -70,13 +70,19 @@ struct check_shapes
check_dynamic
();
}
template
<
class
Op
>
template
<
class
Op
,
MIGRAPHX_REQUIRES
(
not
std
::
is_convertible
<
Op
,
std
::
string
>{})
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
void
check_dynamic
()
const
{
if
(
not
dynamic_allowed
and
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
...
...
@@ -147,7 +153,7 @@ struct check_shapes
{
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
...
...
@@ -162,7 +168,7 @@ struct check_shapes
{
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
}
...
...
@@ -178,7 +184,7 @@ struct check_shapes
{
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
}
...
...
@@ -228,6 +234,16 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same layout.
*/
const
check_shapes
&
same_layout
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
find_permutation
(
s
);
}))
MIGRAPHX_THROW
(
prefix
()
+
"Layouts do not match"
);
return
*
this
;
}
/*!
* Check all shapes are standard.
*/
...
...
@@ -238,6 +254,16 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
* Check all shapes are standard or scalar.
*/
...
...
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct
MIGRAPHX_EXPORT
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/convolution.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
shape
win_shape
{
output_shape
.
type
(),
win_size
};
double
acc
=
0.0
;
shape_for_each
(
win_shape
,
[
&
](
auto
idx_win
)
{
shape_for_each
(
win_shape
,
[
&
](
const
auto
&
idx_win
)
{
auto
k
=
idx_win
[
0
];
const
auto
in_ch
=
group_id
*
wei_c
+
k
;
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
...
...
src/include/migraphx/instruction.hpp
View file @
250d3c87
...
...
@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction
const
std
::
vector
<
module_ref
>&
module_inputs
()
const
;
/// Where this instruction is used as an input to another instruction
const
std
::
vector
<
instruction_ref
>&
outputs
()
const
;
friend
bool
operator
==
(
const
instruction
&
x
,
const
instruction
&
y
);
...
...
src/include/migraphx/matcher.hpp
View file @
250d3c87
...
...
@@ -381,22 +381,24 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
const
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
const
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
const
auto
trace_filter
=
string_value_of
(
MIGRAPHX_TRACE_MATCHES_FOR
{});
const
bool
trace_for
=
not
trace_filter
.
empty
()
and
(
contains
(
std
::
string
{
location
.
file_name
()},
trace_filter
)
or
contains
(
std
::
string
{
location
.
function_name
()},
trace_filter
));
bool
match
=
false
;
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
const
auto
&
matcher_name
=
get_type_name
(
m
);
const
bool
trace_for
=
not
trace_filter
.
empty
()
and
(
contains
(
std
::
string
{
location
.
file_name
()},
trace_filter
)
or
contains
(
std
::
string
{
location
.
function_name
()},
trace_filter
)
or
contains
(
matcher_name
,
trace_filter
));
if
(
match
)
return
;
if
(
trace
>
1
or
trace_for
)
std
::
cout
<<
"Match: "
<<
get_type
_name
(
m
)
<<
std
::
endl
;
if
(
trace
>
1
and
trace_for
)
std
::
cout
<<
"Match: "
<<
matcher
_name
<<
std
::
endl
;
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
if
(
r
.
result
==
get_module
(
mod
).
end
())
return
;
if
(
trace
>
0
or
trace_for
)
{
std
::
cout
<<
"Matched by "
<<
get_type
_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
matcher
_name
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
}
// If its already invalid dont validate it again
...
...
@@ -407,7 +409,7 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
auto
invalid
=
get_module
(
mod
).
validate
();
if
(
invalid
!=
get_module
(
mod
).
end
())
{
std
::
cout
<<
"Invalid program from match: "
<<
get_type
_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Invalid program from match: "
<<
matcher
_name
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
);
...
...
@@ -621,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
{
static_assert
(((
not
std
::
is_convertible
<
Ms
,
std
::
string
>
{})
and
...),
"Use a matcher not a string for skip."
);
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
optional
<
instruction_ref
>>
(
...
...
src/include/migraphx/op/allocate.hpp
View file @
250d3c87
...
...
@@ -36,20 +36,53 @@ namespace op {
struct
allocate
{
shape
s
{};
// for dynamic allocate to set the buffer type
shape
::
type_t
buf_type
=
shape
::
half_type
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
s
,
"shape"
));
return
pack
(
f
(
self
.
s
,
"shape"
)
,
f
(
self
.
buf_type
,
"buf_type"
)
);
}
std
::
string
name
()
const
{
return
"allocate"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
migraphx
::
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
0
);
return
s
;
if
(
s
!=
shape
())
{
if
(
inputs
.
size
()
==
1
)
{
migraphx
::
check_shapes
{
inputs
,
*
this
,
false
}.
only_dims
(
1
);
}
else
{
migraphx
::
check_shapes
{
inputs
,
*
this
,
false
}.
has
(
0
);
}
return
s
;
}
else
{
migraphx
::
check_shapes
{
inputs
,
*
this
,
false
}.
has
(
1
).
only_dims
(
1
);
const
auto
&
out_dims
=
inputs
.
at
(
0
);
std
::
size_t
max_val
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
out_dims
.
lens
().
at
(
0
),
shape
::
dynamic_dimension
{
0
,
max_val
});
return
{
buf_type
,
dyn_dims
};
}
}
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
{
output_shape
};
if
(
args
.
empty
())
{
return
{
output_shape
};
}
else
{
std
::
vector
<
std
::
size_t
>
output_dims
(
output_shape
.
ndim
());
args
.
at
(
0
).
visit
([
&
](
auto
a
)
{
output_dims
.
assign
(
a
.
begin
(),
a
.
end
());
});
return
{
shape
{
buf_type
,
output_dims
}};
}
}
};
...
...
src/include/migraphx/op/common.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -33,8 +33,12 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
// Specifies where to add the "extra" cell of padding if the
// calculated padding is an odd number.
// Padding mode is default_ for fixed shape padding.
// same_lower and same_upper used for dynamic padding.
// same_lower and same_upper specify dynamic padding.
// The odd cell goes at the beginning of the dimension
// (same_lower) or end (same_upper).
enum
padding_mode_t
{
default_
,
// NOLINT
...
...
src/include/migraphx/op/contiguous.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
src/include/migraphx/op/convolution.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -206,6 +206,7 @@ struct convolution
std
::
vector
<
std
::
size_t
>
new_padding
;
if
(
padding_mode
!=
op
::
padding_mode_t
::
default_
)
{
// auto-Calculate the padding sizes with calc_dyn_auto_pad
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
auto
weights_lens
=
args
[
1
].
get_shape
().
lens
();
new_padding
=
...
...
@@ -217,6 +218,7 @@ struct convolution
}
else
{
// Use the padding that was given
new_padding
=
padding
;
if
(
output_shape
.
dynamic
())
{
...
...
src/include/migraphx/op/convolution_backwards.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -164,7 +164,7 @@ struct convolution_backwards
shape
win_shape
{
dyn_out
.
computed_shape
.
type
(),
win_size
};
par_dfor
(
in_n
,
wei_c
)([
&
](
int
o
,
int
k
)
{
shape_for_each
(
win_shape
,
[
&
](
auto
idx_win
)
{
shape_for_each
(
win_shape
,
[
&
](
const
auto
&
idx_win
)
{
const
int
w
=
idx_win
[
0
];
auto
input_dims_start
=
idx_win
.
begin
()
+
1
;
...
...
src/include/migraphx/op/fill.hpp
0 → 100644
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_FILL_HPP
#define MIGRAPHX_GUARD_OPERATORS_FILL_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* fill(default_value, output_buffer)
* Fill an output buffer with the given default_value.
* Note that if the default_value is a literal and the output_buffer
* has a static shape this operator can be replaced with a literal.
*/
struct
fill
{
std
::
string
name
()
const
{
return
"fill"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
).
same_type
();
if
(
inputs
.
at
(
0
).
dynamic
()
or
inputs
.
at
(
0
).
elements
()
!=
1
)
{
MIGRAPHX_THROW
(
"FILL: default_value is dynamic or more than one element"
);
}
return
inputs
.
back
();
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
value
,
auto
output
)
{
par_for
(
dyn_out
.
computed_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
value
.
front
();
});
});
return
args
[
1
];
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
1
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/gather.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -125,13 +125,12 @@ struct gather
auto
out_lens
=
data
.
get_shape
().
lens
();
out_lens
[
axis
]
=
indices
.
get_shape
().
elements
();
migraphx
::
shape
out_comp_shape
{
data
.
get_shape
().
type
(),
out_lens
};
shape_for_each
(
out_comp_shape
,
[
&
](
const
auto
&
out_idx
)
{
auto
data_idx
=
out_idx
;
auto
in_index
=
indices
[
data_idx
[
axis
]];
in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
data_idx
[
axis
]
=
in_index
;
output
[
out_comp_shape
.
index
(
out_idx
.
begin
(),
out_idx
.
end
())]
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
shape_for_each
(
out_comp_shape
,
[
&
](
const
auto
&
out_idx_v
,
size_t
out_idx
)
{
auto
data_idx
=
out_idx_v
;
auto
in_index
=
indices
[
data_idx
[
axis
]];
in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
data_idx
[
axis
]
=
in_index
;
output
[
out_idx
]
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
});
}
});
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -258,7 +258,7 @@ struct nonmaxsuppression
selected_boxes_inside_class
.
reserve
(
max_output_shape
.
elements
());
// iterate over batches and classes
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
comp_s
,
[
&
](
const
auto
&
idx
)
{
auto
batch_idx
=
idx
[
0
];
auto
class_idx
=
idx
[
1
];
// index offset for this class
...
...
src/include/migraphx/op/nonzero.hpp
View file @
250d3c87
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -56,10 +56,10 @@ struct nonzero
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_idx
;
auto
s
=
args
.
front
().
get_shape
();
args
.
front
().
visit
([
&
](
auto
v
)
{
shape_for_each
(
s
,
[
&
](
auto
idx
)
{
if
(
not
float_equal
(
v
[
s
.
index
(
idx
)
],
0
))
shape_for_each
(
s
,
[
&
](
const
auto
&
idx_v
,
size_t
idx
)
{
if
(
not
float_equal
(
v
[
idx
],
0
))
{
vec_idx
.
push_back
(
idx
);
vec_idx
.
push_back
(
idx
_v
);
}
});
});
...
...
src/include/migraphx/op/pooling.hpp
View file @
250d3c87
...
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/dyn_output.hpp>
...
...
@@ -40,10 +41,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
// The Pooling operator mostly follows the specifications for the Onnx pooling op.
// It assumes an NCHW layout, extended to support any number of spatial dimensions
// from 1 on up; dimensions are <batch index, channels, spatial dimensions...>
//
struct
pooling
{
// Class members mode, ceil_mode, padding_mode have similar names but refer to separate
// concepts.
pooling_mode
mode
=
{
pooling_mode
::
average
};
// If the input has rank other than 4 then padding, stride, lengths must all be specified
// since the defaults have 2-dimensions. Exception: padding not required if
// padding_mode != default_
// Padding along each spatial input dimension
// Can be ndim or 2*ndim values where ndim is size of lengths
// ndim values means pad the same before and after each dimension
...
...
@@ -63,13 +74,14 @@ struct pooling
// ceiling mode is a flag affecting output size
// or equivalently, placements of the pooling kernel.
// When true, round the size upwards, possibly
// including partial placements where the kernel extends beyond the edge
// of input and even padding. When false, round down so that all
// When true, round the size upwards. When false, round down so that all
// kernel placements fit but some input values may be dropped.
bool
ceil_mode
=
false
;
int
lp_order
=
2
;
// Mode for auto padding. default_ indicates no auto padding.
padding_mode_t
padding_mode
=
padding_mode_t
::
default_
;
// Global pooling with dynamic shape input
bool
dyn_global
=
false
;
...
...
@@ -84,6 +96,7 @@ struct pooling
{
return
pack
(
f
(
self
.
mode
,
"mode"
),
f
(
self
.
padding
,
"padding"
),
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
),
f
(
self
.
ceil_mode
,
"ceil_mode"
),
...
...
@@ -97,7 +110,8 @@ struct pooling
{
if
(
dyn_global
)
return
;
if
((
padding
.
size
()
!=
stride
.
size
()
and
(
padding
.
size
())
!=
stride
.
size
()
*
2
)
or
if
((
padding_mode
!=
default_
and
padding
.
size
()
!=
stride
.
size
()
and
(
padding
.
size
())
!=
stride
.
size
()
*
2
)
or
stride
.
size
()
!=
lengths
.
size
())
{
MIGRAPHX_THROW
(
"POOLING: inconsistent attribute sizes"
);
...
...
@@ -137,8 +151,19 @@ struct pooling
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
if
(
padding
.
size
()
==
2
*
kdims
)
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
std
::
size_t
dim_size
;
if
(
input_lens
[
i
+
2
]
+
padding_factor
<
lengths
[
i
])
{
if
(
padding_mode
==
default_
)
MIGRAPHX_THROW
(
"POOLING: not enough padding for the given kernel size"
);
// lengths can be legitimately larger only if we're doing auto padding
// with a dynamic shape, in which case given padding is ignored. Set a dummy value.
dim_size
=
2
;
}
else
{
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
}
std
::
size_t
len
=
(
ceil_mode
)
?
dim_size
/
stride
[
i
]
+
...
...
@@ -151,17 +176,13 @@ struct pooling
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
)
.
min_ndims
(
3
)
;
check_attribute_size
();
const
shape
&
input
=
inputs
.
at
(
0
);
auto
padding
_size
=
padding
.
size
();
auto
stride
_size
=
stride
.
size
();
size_t
kdims
=
input
.
ndim
()
-
2
;
if
(
input
.
ndim
()
<
3
)
{
MIGRAPHX_THROW
(
"POOLING: input must have 3 or more dimensions and be nonempty"
);
}
if
(
input
.
ndim
()
*
2
!=
padding_size
+
4
and
input
.
ndim
()
!=
padding_size
+
2
)
if
(
input
.
ndim
()
!=
stride_size
+
2
)
{
MIGRAPHX_THROW
(
"POOLING: input and attribute size mismatch!"
);
}
...
...
@@ -179,6 +200,28 @@ struct pooling
}
return
{
input
.
type
(),
output_dyn_dims
};
}
else
if
(
padding_mode
!=
default_
)
{
const
size_t
num_spatial_dims
=
inputs
[
0
].
ndim
()
-
2
;
const
shape
&
x_shape
=
inputs
[
0
];
// same as convolution::dynamic_compute_shape()
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
auto
ceil_div
=
[](
std
::
size_t
x
,
std
::
size_t
y
)
{
return
(
x
+
y
-
1
)
/
y
;
};
auto
s
=
stride
[
i
];
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
std
::
set
<
std
::
size_t
>
optimals
{};
std
::
transform
(
x
.
optimals
.
begin
(),
x
.
optimals
.
end
(),
std
::
inserter
(
optimals
,
optimals
.
begin
()),
[
&
](
auto
o
)
{
return
ceil_div
(
o
,
s
);
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
optimals
});
}
return
{
input
.
type
(),
output_dyn_dims
};
}
else
{
// does not compute optimals
...
...
@@ -267,6 +310,7 @@ struct pooling
Out
&
output
,
const
In
&
input
,
const
std
::
vector
<
std
::
size_t
>&
kernel_dims
,
const
std
::
vector
<
std
::
size_t
>&
padding_vals
,
Op
op
)
const
{
auto
in_s
=
input
.
get_shape
();
...
...
@@ -283,9 +327,9 @@ struct pooling
// For each spatial dimension, find starting and ending index of pooling kernel
for
(
std
::
size_t
dim
=
2
;
dim
<
n_dim
;
++
dim
)
{
auto
d_2
=
dim
-
2
;
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
[
d_2
]);
auto
d_2
=
dim
-
2
;
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
_vals
[
d_2
]);
int
end
;
// NOLINT
if
(
count_include_pad
and
ceil_mode
and
(
mode
!=
pooling_mode
::
max
))
...
...
@@ -297,7 +341,7 @@ struct pooling
// Check if this kernel extends beyond the padding at end of dimension
end
=
std
::
min
(
start
+
kernel_dims
[
d_2
],
in_lens
[
dim
]
+
static_cast
<
int
>
(
padding
[
d_2
]));
in_lens
[
dim
]
+
static_cast
<
int
>
(
padding
_vals
[
d_2
]));
}
else
{
...
...
@@ -316,11 +360,12 @@ struct pooling
}
shape
win_shape
{
output_shape
.
type
(),
win_size
};
auto
pool_size
=
win_shape
.
elements
();
double
output_val
=
op
.
template
init
<
Type
>();
// for each element in the window...
shape_for_each
(
win_shape
,
[
&
](
auto
idx_w
)
{
shape_for_each
(
win_shape
,
[
&
](
const
auto
&
idx_w
)
{
// the coordinates of this element
auto
idx
=
idx_o
;
...
...
@@ -354,30 +399,65 @@ struct pooling
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
dyn_out
.
computed_shape
}
;
argument
result
;
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
std
::
vector
<
std
::
size_t
>
kernel_dims
;
shape
output_shape
;
// If we have to auto-calculate padding, it will be passed to calc_pooling() as an argument
// instead of the member variable padding.
std
::
vector
<
std
::
size_t
>
temp_padding
(
padding
);
if
(
dyn_global
)
{
// for dynamic GlobalPooling, there's no padding
kernel_dims
.
insert
(
kernel_dims
.
end
(),
input_lens
.
begin
()
+
2
,
input_lens
.
end
());
output_shape
=
dyn_out
.
computed_shape
;
result
=
dyn_out
.
computed_shape
;
}
else
else
if
((
padding_mode
!=
op
::
padding_mode_t
::
default_
))
{
// if padding_mode is set, input was a dynamic size. Calculate padded size now.
// kernel_lens is the same as kernel_dims, but prepended with the 2 non-
// spatial dimensions. For size computations, it's used like the weights
// tensor for convolutions.
std
::
vector
<
std
::
size_t
>
kernel_lens
;
kernel_lens
.
insert
(
kernel_lens
.
end
(),
input_lens
.
begin
(),
input_lens
.
begin
()
+
2
);
kernel_lens
.
insert
(
kernel_lens
.
end
(),
lengths
.
begin
(),
lengths
.
end
());
kernel_dims
=
this
->
lengths
;
auto
type
=
args
[
0
].
get_shape
().
type
();
// dilation not currently supported for pooling, so default to all 1's
temp_padding
=
calc_dyn_auto_pad
(
input_lens
,
kernel_lens
,
stride
,
{
1
,
1
},
bool
(
padding_mode
==
op
::
same_upper
));
output_shape
=
compute_padded_pool_shape
(
args
[
0
].
get_shape
(),
shape
(
type
,
kernel_dims
),
temp_padding
,
stride
,
{
1
,
1
});
result
=
argument
(
output_shape
);
}
else
// fixed/static input
{
kernel_dims
=
this
->
lengths
;
output_shape
=
dyn_out
.
computed_shape
;
result
=
dyn_out
.
computed_shape
;
}
// Perform the computation and populate result
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
switch
(
mode
)
{
case
migraphx
::
op
::
pooling_mode
::
average
:
calc_pooling
<
type
>
(
dyn_out
.
computed_shape
,
output
,
input
,
kernel_dims
,
avg_pool
{});
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
kernel_dims
,
temp_padding
,
avg_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
max
:
calc_pooling
<
type
>
(
dyn_out
.
computed_shape
,
output
,
input
,
kernel_dims
,
max_pool
{});
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
kernel_dims
,
temp_padding
,
max_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
lpnorm
:
calc_pooling
<
type
>
(
dyn_out
.
com
put
ed
_shape
,
output
,
input
,
kernel_dims
,
lpnorm_pool
{
lp_order
});
out
put_shape
,
output
,
input
,
kernel_dims
,
temp_padding
,
lpnorm_pool
{
lp_order
});
break
;
}
});
...
...
Prev
1
2
3
4
5
6
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment