Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6be7f1fb
Unverified
Commit
6be7f1fb
authored
Oct 31, 2022
by
Brian Pickrell
Committed by
GitHub
Oct 31, 2022
Browse files
Merge branch 'develop' into dynamic_reduce
parents
6ebe1df0
5ba656a3
Changes
40
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
359 additions
and
19 deletions
+359
-19
src/targets/cpu/target.cpp
src/targets/cpu/target.cpp
+0
-1
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+2
-3
src/targets/gpu/hip.cpp
src/targets/gpu/hip.cpp
+2
-2
src/targets/gpu/jit/pad.cpp
src/targets/gpu/jit/pad.cpp
+100
-0
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
+63
-0
src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp
+49
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+6
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-2
test/gpu/literal.cpp
test/gpu/literal.cpp
+1
-1
test/gpu/quantization.cpp
test/gpu/quantization.cpp
+0
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+0
-1
test/operators.cpp
test/operators.cpp
+0
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+15
-1
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+0
-1
test/verify/test_pad_large.cpp
test/verify/test_pad_large.cpp
+42
-0
test/verify/test_reduce_op_large.cpp
test/verify/test_reduce_op_large.cpp
+14
-1
test/verify/test_shape_alloc.cpp
test/verify/test_shape_alloc.cpp
+61
-0
tools/accuracy/requirements.txt
tools/accuracy/requirements.txt
+1
-1
tools/install_prereqs.sh
tools/install_prereqs.sh
+1
-1
No files found.
src/targets/cpu/target.cpp
View file @
6be7f1fb
...
...
@@ -41,7 +41,6 @@
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
6be7f1fb
...
...
@@ -145,8 +145,7 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
)
{
const
std
::
size_t
min_block_size
=
64
;
const
std
::
size_t
base_block_size
=
32
;
auto
block_size
=
(((
n
-
1
)
/
base_block_size
+
1
))
*
base_block_size
;
auto
block_size
=
(((
n
-
1
)
/
min_block_size
+
1
))
*
min_block_size
;
return
std
::
min
(
std
::
max
(
min_block_size
,
block_size
),
max_block_size
);
}
...
...
src/targets/gpu/hip.cpp
View file @
6be7f1fb
...
...
@@ -183,8 +183,8 @@ argument register_on_gpu(const argument& arg)
{
auto
arg_shared
=
arg
.
share
();
auto
p
=
register_on_gpu
(
arg_shared
.
data
(),
arg_shared
.
get_shape
().
bytes
());
return
{
arg_shared
.
get_shape
()
,
[
p
,
a
=
std
::
move
(
arg_shared
)]()
mutable
{
return
get_device_ptr
(
p
.
get
());
}};
auto
s
=
arg_shared
.
get_shape
()
;
return
{
s
,
[
p
,
a
=
std
::
move
(
arg_shared
)]()
mutable
{
return
get_device_ptr
(
p
.
get
());
}};
}
argument
to_gpu
(
const
argument
&
arg
,
bool
host
)
...
...
src/targets/gpu/jit/pad.cpp
0 → 100644
View file @
6be7f1fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/float_equal.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/pad.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void pad_kernel(void* input_p, void* output_p)
{
auto offsets = index_ints<${offsets}>{};
auto idx = make_index();
make_tensors()(input_p, output_p)([&](auto input, auto output) {
pad(idx, offsets, input, output, ${pad_val});
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
pad_compiler
:
compiler
<
pad_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pad"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
kernel_name
=
"pad_kernel"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
auto
pad_val
=
v
.
get
(
"value"
,
0.
f
);
auto
pad_val_string
=
to_string
(
pad_val
);
if
(
float_equal
(
pad_val
,
std
::
numeric_limits
<
float
>::
lowest
()))
pad_val_string
=
"lowest{}"
;
if
(
float_equal
(
pad_val
,
std
::
numeric_limits
<
float
>::
max
()))
pad_val_string
=
"highest{}"
;
auto
padding
=
v
.
at
(
"pads"
).
to_vector
<
int64_t
>
();
auto
input_lens
=
inputs
.
front
().
lens
();
std
::
vector
<
size_t
>
offsets
(
input_lens
.
size
());
std
::
copy
(
padding
.
begin
(),
padding
.
begin
()
+
offsets
.
size
(),
offsets
.
begin
());
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"pad_val"
,
to_string
(
pad_val_string
)},
{
"offsets"
,
to_string_range
(
offsets
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
0 → 100644
View file @
6be7f1fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_PAD_HPP
#define MIGRAPHX_GUARD_KERNELS_PAD_HPP
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
namespace
migraphx
{
template
<
class
Offsets
,
class
Input
,
class
Output
,
class
PadVal
>
__device__
void
pad
(
const
index
&
idx
,
const
Offsets
&
offsets
,
const
Input
&
input
,
Output
&
output
,
const
PadVal
&
pad_val
)
{
auto
output_shape
=
output
.
get_shape
();
idx
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
// 1. get current multi-index for output
// 2. get the size of the input to determine input boundaries
// 3. compute the corresponding multi-index for input by accounting for offsets
// 4. if current multi-index is within offsets or input's new multi-index is out of bounds,
// use pad value instead of input's value
auto
multi
=
output_shape
.
multi
(
i
);
auto
input_bounds
=
input
.
get_shape
().
lens
;
auto
input_idx
=
multi
-
offsets
;
auto
range_multi
=
range
(
multi
.
size
());
if
(
any_of
(
range_multi
.
begin
(),
range_multi
.
end
(),
[
&
](
auto
j
)
{
return
multi
[
j
]
<
offsets
[
j
]
or
input_idx
[
j
]
>=
input_bounds
[
j
];
}))
output
[
multi
]
=
pad_val
;
else
output
[
multi
]
=
input
[
input_idx
];
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp
0 → 100644
View file @
6be7f1fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_RANGES_HPP
#define MIGRAPHX_GUARD_KERNELS_RANGES_HPP
#include <migraphx/kernels/iota_iterator.hpp>
namespace
migraphx
{
template
<
class
Iterator
>
struct
iterator_range
{
Iterator
start
;
Iterator
last
;
constexpr
Iterator
begin
()
const
{
return
start
;
}
constexpr
Iterator
end
()
const
{
return
last
;
}
};
constexpr
iterator_range
<
iota_iterator
>
range
(
diff_int
start
,
diff_int
last
)
{
return
{{
start
,
{}},
{
last
,
{}}};
}
constexpr
iterator_range
<
iota_iterator
>
range
(
diff_int
last
)
{
return
range
(
0
,
last
);
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_RANGES_HPP
src/targets/gpu/lowering.cpp
View file @
6be7f1fb
...
...
@@ -100,7 +100,6 @@ struct miopen_apply
add_extend_op
(
"lrn"
);
add_extend_op
(
"multinomial"
);
add_extend_op
(
"nonzero"
);
add_extend_op
(
"pad"
);
add_extend_op
(
"pooling"
);
add_extend_op
(
"prefix_scan_sum"
);
add_extend_op
(
"reverse"
);
...
...
src/targets/gpu/mlir.cpp
View file @
6be7f1fb
...
...
@@ -196,6 +196,7 @@ struct mlir_program
MlirType
make_tensor
(
const
shape
&
s
)
const
{
assert
(
s
.
standard
());
std
::
vector
<
int64_t
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
());
return
mlirRankedTensorTypeGet
(
lens
.
size
(),
lens
.
data
(),
make_type
(
s
.
type
()),
mlirAttributeGetNull
());
...
...
@@ -371,7 +372,11 @@ struct mlir_program
mlir_operation_state
&
add_results
(
const
std
::
vector
<
shape
>&
outputs
)
{
auto
x
=
prog
->
make_tensors
(
outputs
);
std
::
vector
<
shape
>
reshaped
(
outputs
.
size
());
std
::
transform
(
outputs
.
begin
(),
outputs
.
end
(),
reshaped
.
begin
(),
[](
const
shape
&
r
)
{
return
shape
{
r
.
type
(),
r
.
lens
()};
});
auto
x
=
prog
->
make_tensors
(
reshaped
);
mlirOperationStateAddResults
(
&
op_state
,
x
.
size
(),
x
.
data
());
return
*
this
;
}
...
...
src/targets/gpu/target.cpp
View file @
6be7f1fb
...
...
@@ -138,12 +138,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
pack_int8_args
{},
dead_code_elimination
{},
adjust_allocation
{
gpu_allocation_model
{}},
dead_code_elimination
{},
fuse_ops
{
&
ctx
,
options
.
fast_math
},
dead_code_elimination
{},
replace_allocate
{
gpu_allocation_model
{},
options
.
offload_copy
},
dead_code_elimination
{},
adjust_allocation
{
gpu_allocation_model
{}},
dead_code_elimination
{},
compile_ops
{
&
ctx
},
dead_code_elimination
{},
write_literals
{
&
ctx
},
...
...
test/gpu/literal.cpp
View file @
6be7f1fb
...
...
@@ -48,4 +48,4 @@ void gpu_literal_test()
}
}
int
main
()
{
gpu_literal_test
();
}
int
main
()
{
gpu_literal_test
();
}
// NOLINT (bugprone-exception-escape)
test/gpu/quantization.cpp
View file @
6be7f1fb
...
...
@@ -30,7 +30,6 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/pass_manager.hpp>
...
...
test/onnx/onnx_test.cpp
View file @
6be7f1fb
...
...
@@ -42,7 +42,6 @@
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/unknown.hpp>
#include <random>
#include <migraphx/serialize.hpp>
...
...
test/operators.cpp
View file @
6be7f1fb
...
...
@@ -29,7 +29,6 @@
#include <migraphx/module.hpp>
#include <sstream>
#include <string>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
...
...
test/ref_ops_test.cpp
View file @
6be7f1fb
...
...
@@ -31,7 +31,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
...
...
@@ -4306,6 +4305,21 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pad_test_asym)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1}}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(9);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pad_test_highest_half)
{
migraphx::program p;
...
...
test/simplify_qdq_test.cpp
View file @
6be7f1fb
...
...
@@ -33,7 +33,6 @@
#include <migraphx/matcher.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool
is_convolution
(
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"convolution"
;
}
...
...
test/verify/test_pad_large.cpp
0 → 100644
View file @
6be7f1fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_pad_large
:
verify_program
<
test_pad_large
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
586
,
3
,
224
,
224
}};
std
::
vector
<
int64_t
>
pads0
=
{
0
,
0
,
1
,
1
,
0
,
0
,
1
,
1
};
auto
l0
=
mm
->
add_parameter
(
"x"
,
s0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"pad"
,
{{
"pads"
,
pads0
}}),
l0
);
return
p
;
}
};
test/verify/test_reduce_op_large.cpp
View file @
6be7f1fb
...
...
@@ -51,7 +51,7 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap
template
struct
test_reduce_op_large
<
migraphx
::
op
::
reduce_prod
,
2
,
migraphx
::
shape
::
float_type
>;
template
struct
test_reduce_op_large
<
migraphx
::
op
::
reduce_sum
,
1
,
migraphx
::
shape
::
float_type
>;
struct
test_reduce_mean
:
verify_program
<
test_reduce_mean
>
struct
test_reduce_mean
_1
:
verify_program
<
test_reduce_mean
_1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -63,3 +63,16 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
return
p
;
};
};
struct
test_reduce_mean_2
:
verify_program
<
test_reduce_mean_2
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
336
,
400
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
mm
->
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
1
}},
x
);
return
p
;
};
};
test/verify/test_shape_alloc.cpp
0 → 100644
View file @
6be7f1fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/reduce_mean.hpp>
/**
* @brief test_shape_alloc sets up a situation that could lead to an exception "convolution: Shapes
* are not in standard layout" if a "replace_allocate" compiler pass is not followed with
* "adjust_allocation". The last transpose instruction generates a shape with a stride of 1 in
* the 2nd index, a non-standard layout that should be reallocated by adjust_allocation.
*/
struct
test_shape_alloc
:
verify_program
<
test_shape_alloc
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
weights
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
11
,
8
,
1
,
1
},
{
8
,
1
,
1
,
1
}}));
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
8
,
7
,
7
}});
auto
transpose1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
x
);
// -> float_type, {1, 7, 7, 8}, {392, 7, 1, 49}
auto
reduce_ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
,
2
}}}),
transpose1
);
// -> float_type, {1, 1, 1, 8}, {8, 8, 8, 1}
auto
transpose2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
1
,
2
}}}),
reduce_ins
);
// -> float_type, {1, 8, 1, 1}, {8, 1, 8, 8}
auto
conv_op
=
migraphx
::
make_op
(
"convolution"
);
mm
->
add_instruction
(
conv_op
,
transpose2
,
weights
);
return
p
;
}
};
tools/accuracy/requirements.txt
View file @
6be7f1fb
...
...
@@ -21,5 +21,5 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
numpy==1.
18.5
numpy==1.
21.6
onnxruntime==1.10.0
tools/install_prereqs.sh
View file @
6be7f1fb
...
...
@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX"
rbuild prepare
-d
$PREFIX
-s
develop
# install onnx package for unit tests
pip3
install
onnx
==
1.8.1
numpy
==
1.
18.5
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
16.8
pip3
install
onnx
==
1.8.1
numpy
==
1.
21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
16.8
# pin version of protobuf in Python for onnx runtime unit tests
pip3
install
protobuf
==
3.20.0
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment