Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13d14c66
Commit
13d14c66
authored
Oct 24, 2023
by
Brian Pickrell
Browse files
Merge branch 'develop' into dyn_resize_gather
parents
f4e7d9d9
d1abf06f
Changes
420
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
454 additions
and
81 deletions
+454
-81
test/verify/test_scatter_nonstandard_shape.cpp
test/verify/test_scatter_nonstandard_shape.cpp
+49
-0
test/verify/test_shrink.cpp
test/verify/test_shrink.cpp
+86
-0
test/verify/test_squeeze_conv_relu.cpp
test/verify/test_squeeze_conv_relu.cpp
+45
-0
test/verify/test_unsqueeze_conv_relu.cpp
test/verify/test_unsqueeze_conv_relu.cpp
+45
-0
tools/CMakeLists.txt
tools/CMakeLists.txt
+13
-1
tools/accuracy/accuracy_checker.py
tools/accuracy/accuracy_checker.py
+35
-3
tools/accuracy/requirements.txt
tools/accuracy/requirements.txt
+1
-1
tools/api.py
tools/api.py
+10
-10
tools/api/api.cpp
tools/api/api.cpp
+8
-0
tools/api/migraphx.h
tools/api/migraphx.h
+1
-0
tools/build_and_test_onnxrt.sh
tools/build_and_test_onnxrt.sh
+1
-1
tools/check_stamped.py
tools/check_stamped.py
+37
-53
tools/docker/sles.docker
tools/docker/sles.docker
+1
-1
tools/docker/ubuntu_2204.dockerfile
tools/docker/ubuntu_2204.dockerfile
+2
-2
tools/download_models.sh
tools/download_models.sh
+5
-0
tools/generate.py
tools/generate.py
+88
-0
tools/install_prereqs.sh
tools/install_prereqs.sh
+4
-3
tools/license_stamper.py
tools/license_stamper.py
+2
-2
tools/te.py
tools/te.py
+6
-3
tools/test_runner.py
tools/test_runner.py
+15
-1
No files found.
test/verify/test_scatter_nonstandard_shape.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatter_nonstandard_shape
:
verify_program
<
test_scatter_nonstandard_shape
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sd
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
3
},
{
1
,
3
,
2
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
3
},
{
1
,
3
,
2
}};
std
::
vector
<
int
>
vi
=
{
1
,
0
,
2
,
0
,
2
,
1
};
migraphx
::
shape
su
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
3
},
{
1
,
2
,
3
}};
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
pu
=
mm
->
add_parameter
(
"update"
,
su
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter_none"
,
{{
"axis"
,
-
1
}}),
pd
,
li
,
pu
);
mm
->
add_return
({
r
});
return
p
;
}
};
test/verify/test_shrink.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
template
<
migraphx
::
shape
::
type_t
T
>
struct
test_shrink
:
verify_program
<
test_shrink
<
T
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
float
bias
=
1.5
;
float
lambd
=
1.5
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
is
{
T
,
{
2
,
3
}};
std
::
vector
<
float
>
data
;
migraphx
::
shape
::
visit
(
T
,
[
&
](
auto
as
)
{
as
.
is_signed
()
?
data
.
assign
({
-
3.0
,
-
2.0
,
-
1.0
,
0.0
,
1.0
,
2.0
})
:
data
.
assign
({
3.0
,
2.0
,
1.0
,
0.0
,
1.0
,
2.0
});
});
auto
x
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
data
});
auto
lit_bias
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
bias
}});
auto
lit_neg_lambd
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
-
lambd
}});
auto
lit_lambd
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
lambd
}});
auto
x_plus_bias
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
x
,
lit_bias
});
auto
x_min_bias
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"sub"
),
{
x
,
lit_bias
});
auto
cond1
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"less"
),
{
x
,
lit_neg_lambd
});
auto
cond2_a
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"not"
),
{
cond1
});
auto
cond2_b
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"greater"
),
{
x
,
lit_lambd
});
auto
cond2
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"logical_and"
),
{
cond2_a
,
cond2_b
});
auto
mul1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
T
}}),
cond1
);
auto
mul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
T
}}),
cond2
);
auto
first
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
mul1
,
x_plus_bias
});
auto
second
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"mul"
),
{
mul2
,
x_min_bias
});
auto
ret
=
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
first
,
second
});
if
(
ret
->
get_shape
().
type
()
!=
T
)
{
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
T
}}),
ret
);
}
return
p
;
}
};
template
struct
test_shrink
<
migraphx
::
shape
::
double_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
float_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
half_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
int64_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
int32_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
int16_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
int8_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
uint64_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
uint32_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
uint16_type
>;
template
struct
test_shrink
<
migraphx
::
shape
::
uint8_type
>;
test/verify/test_squeeze_conv_relu.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_squeeze_conv_relu
:
verify_program
<
test_squeeze_conv_relu
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
3
,
3
}});
input
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
2
}}}),
input
);
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}});
auto
conv
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
input
,
weights
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
conv
);
return
p
;
}
};
test/verify/test_unsqueeze_conv_relu.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_unsqueeze_conv_relu
:
verify_program
<
test_unsqueeze_conv_relu
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
3
,
3
}});
auto
weights
=
mm
->
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
3
,
3
}});
weights
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
weights
);
auto
conv
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
),
input
,
weights
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
conv
);
return
p
;
}
};
tools/CMakeLists.txt
View file @
13d14c66
...
@@ -22,4 +22,16 @@
...
@@ -22,4 +22,16 @@
# THE SOFTWARE.
# THE SOFTWARE.
#####################################################################################
#####################################################################################
add_custom_target
(
generate bash
${
CMAKE_CURRENT_SOURCE_DIR
}
/generate.sh
)
find_package
(
Python 3 COMPONENTS Interpreter
)
if
(
NOT Python_EXECUTABLE
)
message
(
WARNING
"Python 3 interpreter not found - skipping 'generate' target!"
)
return
()
endif
()
find_program
(
CLANG_FORMAT clang-format PATHS /opt/rocm/llvm ENV HIP_PATH PATH_SUFFIXES bin
)
if
(
NOT CLANG_FORMAT
)
message
(
WARNING
"clang-format not found - skipping 'generate' target!"
)
return
()
endif
()
add_custom_target
(
generate
${
Python_EXECUTABLE
}
generate.py -f
${
CLANG_FORMAT
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
)
tools/accuracy/accuracy_checker.py
View file @
13d14c66
...
@@ -82,6 +82,27 @@ def parse_args():
...
@@ -82,6 +82,27 @@ def parse_args():
default
=
False
,
default
=
False
,
help
=
'Turn on ort VERBOSE logging via session options'
)
help
=
'Turn on ort VERBOSE logging via session options'
)
parser
.
add_argument
(
'--disable-offload-copy'
,
dest
=
"offload_copy"
,
action
=
'store_false'
,
default
=
True
,
help
=
'Disable offload copying (user must handle copy to and from device)'
)
parser
.
add_argument
(
'--disable-fast-math'
,
dest
=
"fast_math"
,
action
=
'store_false'
,
default
=
True
,
help
=
'Disable fast math optimizations (etc: rewrite_gelu)'
)
parser
.
add_argument
(
'--exhaustive_tune'
,
dest
=
"exhaustive_tune"
,
action
=
'store_true'
,
default
=
False
,
help
=
'Enable exhaustive tuning for solutions'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
...
@@ -177,7 +198,12 @@ def main():
...
@@ -177,7 +198,12 @@ def main():
print
(
model
)
print
(
model
)
if
not
args
.
ort_run
:
if
not
args
.
ort_run
:
model
.
compile
(
migraphx
.
get_target
(
args
.
target
))
model
.
compile
(
migraphx
.
get_target
(
args
.
target
),
offload_copy
=
args
.
offload_copy
,
fast_math
=
args
.
fast_math
,
exhaustive_tune
=
args
.
exhaustive_tune
,
)
params
=
{}
params
=
{}
test_inputs
=
{}
test_inputs
=
{}
...
@@ -194,10 +220,16 @@ def main():
...
@@ -194,10 +220,16 @@ def main():
else
:
else
:
test_input
=
np
.
zeros
(
in_shape
).
astype
(
get_np_datatype
(
in_type
))
test_input
=
np
.
zeros
(
in_shape
).
astype
(
get_np_datatype
(
in_type
))
test_inputs
[
name
]
=
test_input
test_inputs
[
name
]
=
test_input
params
[
name
]
=
migraphx
.
argument
(
test_input
)
migraphx_arg
=
migraphx
.
argument
(
test_input
)
if
not
args
.
offload_copy
:
migraphx_arg
=
migraphx
.
to_gpu
(
migraphx_arg
)
params
[
name
]
=
migraphx_arg
if
not
args
.
ort_run
:
if
not
args
.
ort_run
:
pred_migx
=
np
.
array
(
model
.
run
(
params
)[
-
1
])
if
not
args
.
offload_copy
:
pred_migx
=
np
.
array
(
migraphx
.
from_gpu
(
model
.
run
(
params
)[
-
1
]))
else
:
pred_migx
=
np
.
array
(
model
.
run
(
params
)[
-
1
])
if
use_onnx
:
if
use_onnx
:
sess_op
=
ort
.
SessionOptions
()
sess_op
=
ort
.
SessionOptions
()
...
...
tools/accuracy/requirements.txt
View file @
13d14c66
...
@@ -22,4 +22,4 @@
...
@@ -22,4 +22,4 @@
# THE SOFTWARE.
# THE SOFTWARE.
#####################################################################################
#####################################################################################
numpy==1.21.6
numpy==1.21.6
onnxruntime==1.1
0.0
onnxruntime==1.1
6.1
tools/api.py
View file @
13d14c66
...
@@ -27,6 +27,7 @@ import re
...
@@ -27,6 +27,7 @@ import re
import
runpy
import
runpy
from
functools
import
wraps
from
functools
import
wraps
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
pathlib
import
Path
type_map
:
Dict
[
str
,
Callable
[[
'Parameter'
],
None
]]
=
{}
type_map
:
Dict
[
str
,
Callable
[[
'Parameter'
],
None
]]
=
{}
cpp_type_map
:
Dict
[
str
,
str
]
=
{}
cpp_type_map
:
Dict
[
str
,
str
]
=
{}
...
@@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs):
...
@@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs):
return
template
return
template
def
run
(
args
:
List
[
str
])
->
None
:
def
run
(
path
:
Union
[
Path
,
str
])
->
str
:
runpy
.
run_path
(
args
[
0
])
return
template_eval
(
open
(
path
).
read
())
if
len
(
args
)
>
1
:
f
=
open
(
args
[
1
]).
read
()
r
=
template_eval
(
f
)
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
runpy
.
run_path
(
sys
.
argv
[
1
])
if
len
(
sys
.
argv
)
>
2
:
r
=
run
(
sys
.
argv
[
2
])
sys
.
stdout
.
write
(
r
)
sys
.
stdout
.
write
(
r
)
else
:
else
:
sys
.
stdout
.
write
(
generate_c_header
())
sys
.
stdout
.
write
(
generate_c_header
())
sys
.
stdout
.
write
(
generate_c_api_body
())
sys
.
stdout
.
write
(
generate_c_api_body
())
# sys.stdout.write(generate_cpp_header())
# sys.stdout.write(generate_cpp_header())
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
run
(
sys
.
argv
[
1
:])
tools/api/api.cpp
View file @
13d14c66
...
@@ -38,26 +38,32 @@
...
@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <algorithm>
#include <cstdarg>
#include <cstdarg>
namespace
migraphx
{
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
{
disable_exception_catch
=
b
;
disable_exception_catch
=
b
;
}
}
#endif
template
<
class
F
>
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
if
(
disable_exception_catch
)
{
{
f
();
f
();
}
}
else
else
{
{
#endif
try
try
{
{
f
();
f
();
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
{
return
migraphx_status_unknown_error
;
return
migraphx_status_unknown_error
;
}
}
#ifdef MIGRAPHX_BUILD_TESTING
}
}
#endif
return
migraphx_status_success
;
return
migraphx_status_success
;
}
}
...
...
tools/api/migraphx.h
View file @
13d14c66
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
#include <migraphx/api/export.h>
...
...
tools/build_and_test_onnxrt.sh
View file @
13d14c66
...
@@ -40,4 +40,4 @@ echo 'InferenceSessionTests.CheckRunProfilerWithSessionOptions' >> ../../../tool
...
@@ -40,4 +40,4 @@ echo 'InferenceSessionTests.CheckRunProfilerWithSessionOptions' >> ../../../tool
echo
'InferenceSessionTests.CheckRunProfilerWithSessionOptions2'
>>
../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo
'InferenceSessionTests.CheckRunProfilerWithSessionOptions2'
>>
../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo
'InferenceSessionTests.Test3LayerNestedSubgraph'
>>
../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo
'InferenceSessionTests.Test3LayerNestedSubgraph'
>>
../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo
'InferenceSessionTests.Test2LayerNestedSubgraph'
>>
../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
echo
'InferenceSessionTests.Test2LayerNestedSubgraph'
>>
../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
../../../tools/ci_build/github/pai/
migraphx
_test_launcher.sh
||
(
gdb ./onnxruntime_test_all core
-batch
-ex
bt
&&
exit
1
)
../../../tools/ci_build/github/pai/
pai
_test_launcher.sh
||
(
gdb ./onnxruntime_test_all core
-batch
-ex
bt
&&
exit
1
)
tools/check_stamped.py
View file @
13d14c66
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#####################################################################################
#####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -27,11 +27,11 @@ import sys
...
@@ -27,11 +27,11 @@ import sys
debug
=
False
debug
=
False
# The filetypes we want to check for that are stamped
# The filetypes we want to check for that are stamped
# LICENSE is included here as it SHOULD have a li
s
cen
c
e in it otherwise flag it as unstamped
# LICENSE is included here as it SHOULD have a licen
s
e in it otherwise flag it as unstamped
supported_file_types
=
(
".cpp"
,
".hpp"
,
".h"
,
".ipynb"
,
".py"
,
".txt"
,
".sh"
,
supported_file_types
=
(
".cpp"
,
".hpp"
,
".h"
,
".ipynb"
,
".py"
,
".txt"
,
".sh"
,
".bsh"
,
"LICENSE"
,
".cmake"
)
".bsh"
,
"LICENSE"
,
".cmake"
)
#add general stuff we shouldn't stamp and any exceptions here
#
add general stuff we shouldn't stamp and any exceptions here
unsupported_file_types
=
[
unsupported_file_types
=
[
".onnx"
,
".pb"
,
".rst"
,
".jpg"
,
".jpeg"
,
".proto"
,
".md"
,
".clang"
,
".onnx"
,
".pb"
,
".rst"
,
".jpg"
,
".jpeg"
,
".proto"
,
".md"
,
".clang"
,
".weight"
,
".ini"
,
".json"
,
".docker"
,
".git"
,
".rules"
,
".yml"
".weight"
,
".ini"
,
".json"
,
".docker"
,
".git"
,
".rules"
,
".yml"
...
@@ -40,105 +40,89 @@ unsupported_file_types = [
...
@@ -40,105 +40,89 @@ unsupported_file_types = [
specificIgnores
=
(
"digits.txt"
,
"Dockerfile"
,
"Jenkinsfile"
,
""
)
specificIgnores
=
(
"digits.txt"
,
"Dockerfile"
,
"Jenkinsfile"
,
""
)
def
hasKeySequence
(
inputfile
,
key_message
):
def
hasKeySequence
(
inputfile
:
str
,
key_message
:
str
)
->
bool
:
result
=
False
if
key_message
in
inputfile
:
if
key_message
in
inputfile
:
re
sult
=
True
re
turn
True
return
result
return
False
#Simple just open and write stuff to each file with the license stamp
# Simple just open and write stuff to each file with the license stamp
def
openAndCheckFile
(
filename
):
def
needStampCheck
(
filename
:
str
)
->
bool
:
result
=
False
# open save old contents and append things here
#open save old contents and append things here
if
debug
:
print
(
"Open"
,
filename
,
end
=
' '
)
if
debug
is
True
:
print
(
"Open"
,
filename
,
end
=
''
)
try
:
try
:
file
=
open
(
filename
,
'r'
)
file
=
open
(
filename
,
'r'
)
except
OSError
as
e
:
except
OSError
as
e
:
if
debug
is
True
:
if
debug
:
print
(
str
(
e
)
+
"....Open Error: Skipping file "
)
print
(
str
(
e
)
+
"....Open Error: Skipping file "
)
file
.
close
()
file
.
close
()
return
return
False
else
:
else
:
with
file
as
contents
:
with
file
as
contents
:
try
:
try
:
save
=
contents
.
read
()
save
=
contents
.
read
()
hasAmdLic
=
hasKeySequence
(
save
,
"Advanced Micro Devices, Inc. All rights reserved"
)
#Check if we have a licence stamp already
# Check if we have a license stamp already
if
hasAmdLic
is
True
:
if
hasKeySequence
(
if
debug
is
True
:
save
,
print
(
"....Already Stamped: Skipping file "
)
"Advanced Micro Devices, Inc. All rights reserved"
):
if
debug
:
print
(
"....Already Stamped: Skipping file "
)
contents
.
close
()
contents
.
close
()
re
sult
=
Tru
e
re
turn
Fals
e
except
UnicodeDecodeError
as
eu
:
except
UnicodeDecodeError
as
eu
:
if
debug
is
True
:
if
debug
:
print
(
f
"
{
str
(
eu
)
}
...Skipping binary file "
)
print
(
str
(
eu
)
+
"...Skipping binary file "
)
contents
.
close
()
contents
.
close
()
re
sult
=
Tru
e
re
turn
Fals
e
return
result
return
True
# Deterine if filename is desired in the fileTuple past in
# Check if any element in fileTuple is in filename
def
check_filename
(
filename
,
fileTuple
):
def
check_filename
(
filename
:
str
,
fileTuple
:
tuple
or
list
)
->
bool
:
supported
=
False
if
any
([
x
in
filename
for
x
in
fileTuple
]):
for
key
in
fileTuple
:
return
True
if
key
in
filename
:
return
False
supported
=
True
break
return
supported
def
main
():
def
main
()
->
None
:
unsupported_file_types
.
extend
(
specificIgnores
)
unsupported_file_types
.
extend
(
specificIgnores
)
#Get a list of all the tracked files in our git repo
#
Get a list of all the tracked files in our git repo
proc
=
subprocess
.
run
(
"git ls-files --exclude-standard"
,
proc
=
subprocess
.
run
(
"git ls-files --exclude-standard"
,
shell
=
True
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
)
stdout
=
subprocess
.
PIPE
)
fileList
=
proc
.
stdout
.
decode
().
split
(
'
\n
'
)
fileList
=
proc
.
stdout
.
decode
().
split
(
'
\n
'
)
if
debug
is
True
:
if
debug
:
print
(
"Target file list:
\n
"
+
str
(
fileList
))
print
(
"Target file list:
\n
"
+
str
(
fileList
))
unsupportedFiles
=
[]
unsupportedFiles
=
[]
unstampedFiles
=
[]
unstampedFiles
=
[]
unknownFiles
=
[]
unknownFiles
=
[]
for
file
in
fileList
:
for
file
in
fileList
:
supported
=
check_filename
(
file
,
supported_file_types
)
if
check_filename
(
file
,
supported_file_types
):
if
supported
is
True
:
if
needStampCheck
(
file
):
isStamped
=
openAndCheckFile
(
file
)
if
isStamped
is
False
:
unstampedFiles
.
append
(
file
)
unstampedFiles
.
append
(
file
)
elif
check_filename
(
file
,
unsupported_file_types
):
unsupportedFiles
.
append
(
file
)
else
:
else
:
unsupported
=
check_filename
(
file
,
unsupported_file_types
)
unknownFiles
.
append
(
file
)
if
unsupported
is
True
:
unsupportedFiles
.
append
(
file
)
else
:
unknownFiles
.
append
(
file
)
#Do a bunch of checks based on our file lists
#
Do a bunch of checks based on our file lists
if
len
(
unstampedFiles
)
>
0
:
if
len
(
unstampedFiles
)
>
0
:
print
(
"Error: The following "
+
str
(
len
(
unstampedFiles
))
+
print
(
"
\n
Error: The following "
+
str
(
len
(
unstampedFiles
))
+
" files are currently without a license:"
)
" files are currently without a license:"
)
print
(
str
(
unstampedFiles
))
print
(
str
(
unstampedFiles
))
sys
.
exit
(
1
)
sys
.
exit
(
1
)
if
len
(
unknownFiles
)
>
0
:
if
len
(
unknownFiles
)
>
0
:
print
(
"Error: The following "
+
str
(
len
(
unknownFiles
))
+
print
(
"
\n
Error: The following "
+
str
(
len
(
unknownFiles
))
+
" files not handled:"
)
" files not handled:"
)
print
(
str
(
unknownFiles
))
print
(
str
(
unknownFiles
))
sys
.
exit
(
2
)
sys
.
exit
(
2
)
sys
.
exit
(
0
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
tools/docker/sles.docker
View file @
13d14c66
...
@@ -3,7 +3,7 @@ FROM registry.suse.com/suse/sle15:15.4
...
@@ -3,7 +3,7 @@ FROM registry.suse.com/suse/sle15:15.4
RUN
sh
-c
'echo -e "
\
RUN
sh
-c
'echo -e "
\
[rocm]\n
\
[rocm]\n
\
name=rocm\n
\
name=rocm\n
\
baseurl=https://repo.radeon.com/rocm/zyp/5.
6
/main\n
\
baseurl=https://repo.radeon.com/rocm/zyp/5.
7
/main\n
\
enabled=1\n
\
enabled=1\n
\
gpgcheck=1\n
\
gpgcheck=1\n
\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n
\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n
\
...
...
tools/docker/ubuntu_2204.dockerfile
View file @
13d14c66
...
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
...
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
curl
-fsSL
http://repo.radeon.com/rocm/rocm.gpg.key | gpg
--dearmor
-o
/etc/apt/trusted.gpg.d/rocm-keyring.gpg
curl
-fsSL
http://repo.radeon.com/rocm/rocm.gpg.key | gpg
--dearmor
-o
/etc/apt/trusted.gpg.d/rocm-keyring.gpg
# Add rocm repository
# Add rocm repository
RUN
sh
-c
"echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] http://repo.radeon.com/rocm/apt/5.
5
jammy main' > /etc/apt/sources.list.d/rocm.list"
RUN
sh
-c
"echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] http://repo.radeon.com/rocm/apt/5.
7
jammy main' > /etc/apt/sources.list.d/rocm.list"
# From docs.amd.com for installing rocm. Needed to install properly
# From docs.amd.com for installing rocm. Needed to install properly
RUN
sh
-c
"echo 'Package: *
\n
Pin: release o=repo.radeon.com
\n
Pin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"
RUN
sh
-c
"echo 'Package: *
\n
Pin: release o=repo.radeon.com
\n
Pin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"
...
@@ -87,7 +87,7 @@ RUN test -f /usr/local/hash || exit 1
...
@@ -87,7 +87,7 @@ RUN test -f /usr/local/hash || exit 1
RUN
pip3
install
yapf
==
0.28.0
RUN
pip3
install
yapf
==
0.28.0
# Install doc requirements
# Install doc requirements
ADD
doc/requirements.txt /doc-requirements.txt
ADD
doc
s/.sphinx
/requirements.txt /doc-requirements.txt
RUN
pip3
install
-r
/doc-requirements.txt
RUN
pip3
install
-r
/doc-requirements.txt
# Download real models to run onnx unit tests
# Download real models to run onnx unit tests
...
...
tools/download_models.sh
View file @
13d14c66
...
@@ -49,3 +49,8 @@ do
...
@@ -49,3 +49,8 @@ do
curl https://download.onnxruntime.ai/onnx/models/
$name
.tar.gz
--output
$tmp_dir
/
$name
.tar.gz
curl https://download.onnxruntime.ai/onnx/models/
$name
.tar.gz
--output
$tmp_dir
/
$name
.tar.gz
tar
-xzvf
$tmp_dir
/
$name
.tar.gz
--directory
$model_dir
&&
rm
$tmp_dir
/
$name
.tar.gz
tar
-xzvf
$tmp_dir
/
$name
.tar.gz
--directory
$model_dir
&&
rm
$tmp_dir
/
$name
.tar.gz
done
done
# CI jobs can run as a different user then the docker image builder.
# Allow read/write access to the models
chmod
777
$model_dir
tools/generate.py
0 → 100644
View file @
13d14c66
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import
api
,
argparse
,
os
,
runpy
,
subprocess
,
sys
,
te
from
pathlib
import
Path
clang_format_path
=
Path
(
'clang-format.exe'
if
os
.
name
==
'nt'
else
'/opt/rocm/llvm/bin/clang-format'
)
work_dir
=
Path
().
cwd
()
src_dir
=
(
work_dir
/
'../src'
).
absolute
()
migraphx_py_path
=
src_dir
/
'api/migraphx.py'
def
clang_format
(
buffer
,
**
kwargs
):
return
subprocess
.
run
(
f
'
{
clang_format_path
}
-style=file'
,
capture_output
=
True
,
shell
=
True
,
check
=
True
,
input
=
buffer
.
encode
(
'utf-8'
),
cwd
=
work_dir
,
**
kwargs
).
stdout
.
decode
(
'utf-8'
)
def
api_generate
(
input_path
:
Path
,
output_path
:
Path
):
with
open
(
output_path
,
'w'
)
as
f
:
f
.
write
(
clang_format
(
api
.
run
(
input_path
)))
def
te_generate
(
input_path
:
Path
,
output_path
:
Path
):
with
open
(
output_path
,
'w'
)
as
f
:
f
.
write
(
clang_format
(
te
.
run
(
input_path
)))
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'-f'
,
'--clang-format'
,
type
=
Path
)
args
=
parser
.
parse_args
()
global
clang_format_path
if
args
.
clang_format
:
clang_format_path
=
args
.
clang_format
if
not
clang_format_path
.
is_file
():
print
(
f
"
{
clang_format_path
}
: invalid path or not installed"
,
file
=
sys
.
stderr
)
return
try
:
files
=
Path
(
'include'
).
absolute
().
iterdir
()
for
f
in
[
f
for
f
in
files
if
f
.
is_file
()]:
te_generate
(
f
,
src_dir
/
f
'include/migraphx/
{
f
.
name
}
'
)
runpy
.
run_path
(
str
(
migraphx_py_path
))
api_generate
(
work_dir
/
'api/migraphx.h'
,
src_dir
/
'api/include/migraphx/migraphx.h'
)
print
(
'Finished generating header migraphx.h'
)
api_generate
(
work_dir
/
'api/api.cpp'
,
src_dir
/
'api/api.cpp'
)
print
(
'Finished generating source api.cpp'
)
except
subprocess
.
CalledProcessError
as
ex
:
if
ex
.
stdout
:
print
(
ex
.
stdout
.
decode
(
'utf-8'
))
if
ex
.
stderr
:
print
(
ex
.
stdout
.
decode
(
'utf-8'
))
print
(
f
"Command '
{
ex
.
cmd
}
' returned
{
ex
.
returncode
}
"
)
raise
if
__name__
==
"__main__"
:
main
()
tools/install_prereqs.sh
View file @
13d14c66
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#####################################################################################
#####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -51,6 +51,7 @@ else
...
@@ -51,6 +51,7 @@ else
openmp-extras
\
openmp-extras
\
python3-dev
\
python3-dev
\
python3-pip
\
python3-pip
\
python3-venv
\
rocblas-dev
\
rocblas-dev
\
rocm-cmake
rocm-cmake
fi
fi
...
@@ -80,8 +81,8 @@ rbuild prepare -d $PREFIX -s develop
...
@@ -80,8 +81,8 @@ rbuild prepare -d $PREFIX -s develop
if
[[
(
"
${
ID
}
"
!=
"sles"
)
]]
;
then
if
[[
(
"
${
ID
}
"
!=
"sles"
)
]]
;
then
export
CMAKE_ARGS
=
"-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
export
CMAKE_ARGS
=
"-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
pip3
install
onnx
==
1.1
0.2
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
pip3
install
onnx
==
1.1
4.1
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
# pin version of protobuf in Python for onnx runtime unit tests between dist versions
# pin version of protobuf in Python for onnx runtime unit tests between dist versions
pip3
install
protobuf
==
3.20.
0
pip3
install
protobuf
==
3.20.
2
fi
fi
tools/license_stamper.py
View file @
13d14c66
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#####################################################################################
#####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -38,7 +38,7 @@ def getipynb_markdownBlockAsList():
...
@@ -38,7 +38,7 @@ def getipynb_markdownBlockAsList():
'
\t\t
"cell_type": "code",
\n
'
,
'
\t\t
"execution_count": null,
\n
'
,
'
\t\t
"cell_type": "code",
\n
'
,
'
\t\t
"execution_count": null,
\n
'
,
'
\t\t
"metadata": {},
\n
'
,
'
\t\t
"outputs": [],
\n
'
,
'
\t\t
"source": [
\n
'
,
'
\t\t
"metadata": {},
\n
'
,
'
\t\t
"outputs": [],
\n
'
,
'
\t\t
"source": [
\n
'
,
'
\t\t\t\"
# The MIT License (MIT)
\\
n
\"
,
\n
'
,
'
\t\t\t\"
#
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# The MIT License (MIT)
\\
n
\"
,
\n
'
,
'
\t\t\t\"
#
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
\\
n
\"
,
\n
'
,
'
\t\t\t\"
#
\\
n
\"
,
\n
'
,
'
\t\t\t\"
#
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# Permission is hereby granted, free of charge, to any person obtaining a copy
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# Permission is hereby granted, free of charge, to any person obtaining a copy
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# of this software and associated documentation files (the
\'
Software
\'
), to deal
\\
n
\"
,
\n
'
,
'
\t\t\t\"
# of this software and associated documentation files (the
\'
Software
\'
), to deal
\\
n
\"
,
\n
'
,
...
...
tools/te.py
View file @
13d14c66
...
@@ -431,6 +431,9 @@ def template_eval(template, **kwargs):
...
@@ -431,6 +431,9 @@ def template_eval(template, **kwargs):
return
template
return
template
f
=
open
(
sys
.
argv
[
1
]).
read
()
def
run
(
p
):
r
=
template_eval
(
f
)
return
template_eval
(
open
(
p
).
read
())
sys
.
stdout
.
write
(
r
)
if
__name__
==
'__main__'
:
sys
.
stdout
.
write
(
run
(
sys
.
argv
[
1
]))
tools/test_runner.py
View file @
13d14c66
...
@@ -39,6 +39,15 @@ def parse_args():
...
@@ -39,6 +39,15 @@ def parse_args():
type
=
str
,
type
=
str
,
default
=
'gpu'
,
default
=
'gpu'
,
help
=
'Specify where the tests execute (ref, gpu)'
)
help
=
'Specify where the tests execute (ref, gpu)'
)
parser
.
add_argument
(
'--fp16'
,
action
=
'store_true'
,
help
=
'Quantize to fp16'
)
parser
.
add_argument
(
'--atol'
,
type
=
float
,
default
=
1e-3
,
help
=
'The absolute tolerance parameter'
)
parser
.
add_argument
(
'--rtol'
,
type
=
float
,
default
=
1e-3
,
help
=
'The relative tolerance parameter'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
...
@@ -257,6 +266,8 @@ def main():
...
@@ -257,6 +266,8 @@ def main():
# read and compile model
# read and compile model
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
if
args
.
fp16
:
migraphx
.
quantize_fp16
(
model
)
model
.
compile
(
migraphx
.
get_target
(
target
))
model
.
compile
(
migraphx
.
get_target
(
target
))
# get test cases
# get test cases
...
@@ -279,7 +290,10 @@ def main():
...
@@ -279,7 +290,10 @@ def main():
output_data
=
run_one_case
(
model
,
input_data
)
output_data
=
run_one_case
(
model
,
input_data
)
# check output correctness
# check output correctness
ret
=
check_correctness
(
gold_outputs
,
output_data
)
ret
=
check_correctness
(
gold_outputs
,
output_data
,
atol
=
args
.
atol
,
rtol
=
args
.
rtol
)
if
ret
:
if
ret
:
correct_num
+=
1
correct_num
+=
1
...
...
Prev
1
…
17
18
19
20
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment