Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e1ebad1d
Unverified
Commit
e1ebad1d
authored
Oct 11, 2023
by
Chris Austen
Committed by
GitHub
Oct 11, 2023
Browse files
Merge branch 'develop' into verify_fp16_tol
parents
13bb4c2f
34b68ee4
Changes
48
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
212 additions
and
67 deletions
+212
-67
test/verify/test_scatter_nonstandard_shape.cpp
test/verify/test_scatter_nonstandard_shape.cpp
+49
-0
tools/CMakeLists.txt
tools/CMakeLists.txt
+13
-1
tools/api.py
tools/api.py
+10
-10
tools/api/api.cpp
tools/api/api.cpp
+8
-0
tools/api/migraphx.h
tools/api/migraphx.h
+1
-0
tools/check_stamped.py
tools/check_stamped.py
+37
-53
tools/generate.py
tools/generate.py
+88
-0
tools/te.py
tools/te.py
+6
-3
No files found.
test/verify/test_scatter_nonstandard_shape.cpp
0 → 100644
View file @
e1ebad1d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatter_nonstandard_shape
:
verify_program
<
test_scatter_nonstandard_shape
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sd
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
3
},
{
1
,
3
,
2
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
3
},
{
1
,
3
,
2
}};
std
::
vector
<
int
>
vi
=
{
1
,
0
,
2
,
0
,
2
,
1
};
migraphx
::
shape
su
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
3
},
{
1
,
2
,
3
}};
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
pu
=
mm
->
add_parameter
(
"update"
,
su
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter_none"
,
{{
"axis"
,
-
1
}}),
pd
,
li
,
pu
);
mm
->
add_return
({
r
});
return
p
;
}
};
tools/CMakeLists.txt
View file @
e1ebad1d
...
...
@@ -22,4 +22,16 @@
# THE SOFTWARE.
#####################################################################################
add_custom_target
(
generate bash
${
CMAKE_CURRENT_SOURCE_DIR
}
/generate.sh
)
find_package
(
Python 3 COMPONENTS Interpreter
)
if
(
NOT Python_EXECUTABLE
)
message
(
WARNING
"Python 3 interpreter not found - skipping 'generate' target!"
)
return
()
endif
()
find_program
(
CLANG_FORMAT clang-format PATHS /opt/rocm/llvm $ENV{HIP_PATH}
)
if
(
NOT CLANG_FORMAT
)
message
(
WARNING
"clang-format not found - skipping 'generate' target!"
)
return
()
endif
()
add_custom_target
(
generate
${
Python_EXECUTABLE
}
generate.py -f
${
CLANG_FORMAT
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
)
tools/api.py
View file @
e1ebad1d
...
...
@@ -27,6 +27,7 @@ import re
import
runpy
from
functools
import
wraps
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
pathlib
import
Path
type_map
:
Dict
[
str
,
Callable
[[
'Parameter'
],
None
]]
=
{}
cpp_type_map
:
Dict
[
str
,
str
]
=
{}
...
...
@@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs):
return
template
def
run
(
args
:
List
[
str
])
->
None
:
runpy
.
run_path
(
args
[
0
])
if
len
(
args
)
>
1
:
f
=
open
(
args
[
1
]).
read
()
r
=
template_eval
(
f
)
def
run
(
path
:
Union
[
Path
,
str
])
->
str
:
return
template_eval
(
open
(
path
).
read
())
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
runpy
.
run_path
(
sys
.
argv
[
1
])
if
len
(
sys
.
argv
)
>
2
:
r
=
run
(
sys
.
argv
[
2
])
sys
.
stdout
.
write
(
r
)
else
:
sys
.
stdout
.
write
(
generate_c_header
())
sys
.
stdout
.
write
(
generate_c_api_body
())
# sys.stdout.write(generate_cpp_header())
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
run
(
sys
.
argv
[
1
:])
tools/api/api.cpp
View file @
e1ebad1d
...
...
@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <cstdarg>
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
#endif
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
{
f
();
}
else
{
#endif
try
{
f
();
...
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return
migraphx_status_unknown_error
;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return
migraphx_status_success
;
}
...
...
tools/api/migraphx.h
View file @
e1ebad1d
...
...
@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
...
...
tools/check_stamped.py
View file @
e1ebad1d
...
...
@@ -2,7 +2,7 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
...
...
@@ -27,11 +27,11 @@ import sys
debug
=
False
# The filetypes we want to check for that are stamped
# LICENSE is included here as it SHOULD have a li
s
cen
c
e in it otherwise flag it as unstamped
# LICENSE is included here as it SHOULD have a licen
s
e in it otherwise flag it as unstamped
supported_file_types
=
(
".cpp"
,
".hpp"
,
".h"
,
".ipynb"
,
".py"
,
".txt"
,
".sh"
,
".bsh"
,
"LICENSE"
,
".cmake"
)
#add general stuff we shouldn't stamp and any exceptions here
#
add general stuff we shouldn't stamp and any exceptions here
unsupported_file_types
=
[
".onnx"
,
".pb"
,
".rst"
,
".jpg"
,
".jpeg"
,
".proto"
,
".md"
,
".clang"
,
".weight"
,
".ini"
,
".json"
,
".docker"
,
".git"
,
".rules"
,
".yml"
...
...
@@ -40,105 +40,89 @@ unsupported_file_types = [
specificIgnores
=
(
"digits.txt"
,
"Dockerfile"
,
"Jenkinsfile"
,
""
)
def
hasKeySequence
(
inputfile
,
key_message
):
result
=
False
def
hasKeySequence
(
inputfile
:
str
,
key_message
:
str
)
->
bool
:
if
key_message
in
inputfile
:
re
sult
=
True
return
result
re
turn
True
return
False
#Simple just open and write stuff to each file with the license stamp
def
openAndCheckFile
(
filename
):
result
=
False
#open save old contents and append things here
if
debug
is
True
:
print
(
"Open"
,
filename
,
end
=
''
)
# Simple just open and write stuff to each file with the license stamp
def
needStampCheck
(
filename
:
str
)
->
bool
:
# open save old contents and append things here
if
debug
:
print
(
"Open"
,
filename
,
end
=
' '
)
try
:
file
=
open
(
filename
,
'r'
)
except
OSError
as
e
:
if
debug
is
True
:
print
(
str
(
e
)
+
"....Open Error: Skipping file "
)
if
debug
:
print
(
str
(
e
)
+
"....Open Error: Skipping file "
)
file
.
close
()
return
return
False
else
:
with
file
as
contents
:
try
:
save
=
contents
.
read
()
hasAmdLic
=
hasKeySequence
(
save
,
"Advanced Micro Devices, Inc. All rights reserved"
)
#Check if we have a licence stamp already
if
hasAmdLic
is
True
:
if
debug
is
True
:
print
(
"....Already Stamped: Skipping file "
)
# Check if we have a license stamp already
if
hasKeySequence
(
save
,
"Advanced Micro Devices, Inc. All rights reserved"
):
if
debug
:
print
(
"....Already Stamped: Skipping file "
)
contents
.
close
()
re
sult
=
Tru
e
re
turn
Fals
e
except
UnicodeDecodeError
as
eu
:
if
debug
is
True
:
print
(
str
(
eu
)
+
"...Skipping binary file "
)
if
debug
:
print
(
f
"
{
str
(
eu
)
}
...Skipping binary file "
)
contents
.
close
()
re
sult
=
Tru
e
re
turn
Fals
e
return
result
return
True
# Deterine if filename is desired in the fileTuple past in
def
check_filename
(
filename
,
fileTuple
):
supported
=
False
for
key
in
fileTuple
:
if
key
in
filename
:
supported
=
True
break
return
supported
# Check if any element in fileTuple is in filename
def
check_filename
(
filename
:
str
,
fileTuple
:
tuple
or
list
)
->
bool
:
if
any
([
x
in
filename
for
x
in
fileTuple
]):
return
True
return
False
def
main
():
def
main
()
->
None
:
unsupported_file_types
.
extend
(
specificIgnores
)
#Get a list of all the tracked files in our git repo
#
Get a list of all the tracked files in our git repo
proc
=
subprocess
.
run
(
"git ls-files --exclude-standard"
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
)
fileList
=
proc
.
stdout
.
decode
().
split
(
'
\n
'
)
if
debug
is
True
:
print
(
"Target file list:
\n
"
+
str
(
fileList
))
if
debug
:
print
(
"Target file list:
\n
"
+
str
(
fileList
))
unsupportedFiles
=
[]
unstampedFiles
=
[]
unknownFiles
=
[]
for
file
in
fileList
:
supported
=
check_filename
(
file
,
supported_file_types
)
if
supported
is
True
:
isStamped
=
openAndCheckFile
(
file
)
if
isStamped
is
False
:
if
check_filename
(
file
,
supported_file_types
):
if
needStampCheck
(
file
):
unstampedFiles
.
append
(
file
)
elif
check_filename
(
file
,
unsupported_file_types
):
unsupportedFiles
.
append
(
file
)
else
:
unsupported
=
check_filename
(
file
,
unsupported_file_types
)
if
unsupported
is
True
:
unsupportedFiles
.
append
(
file
)
else
:
unknownFiles
.
append
(
file
)
unknownFiles
.
append
(
file
)
#Do a bunch of checks based on our file lists
#
Do a bunch of checks based on our file lists
if
len
(
unstampedFiles
)
>
0
:
print
(
"Error: The following "
+
str
(
len
(
unstampedFiles
))
+
print
(
"
\n
Error: The following "
+
str
(
len
(
unstampedFiles
))
+
" files are currently without a license:"
)
print
(
str
(
unstampedFiles
))
sys
.
exit
(
1
)
if
len
(
unknownFiles
)
>
0
:
print
(
"Error: The following "
+
str
(
len
(
unknownFiles
))
+
print
(
"
\n
Error: The following "
+
str
(
len
(
unknownFiles
))
+
" files not handled:"
)
print
(
str
(
unknownFiles
))
sys
.
exit
(
2
)
sys
.
exit
(
0
)
if
__name__
==
"__main__"
:
main
()
tools/generate.
sh
→
tools/generate.
py
100755 → 100644
View file @
e1ebad1d
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
...
...
@@ -21,23 +21,68 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
CLANG_FORMAT
=
/opt/rocm/llvm/bin/clang-format
SRC_DIR
=
$DIR
/../src
PYTHON
=
python3
if
type
-p
python3.6
>
/dev/null
;
then
PYTHON
=
python3.6
fi
if
type
-p
python3.8
>
/dev/null
;
then
PYTHON
=
python3.8
fi
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"
$PYTHON
$DIR
/te.py
$DIR
/include/{} |
$CLANG_FORMAT
-style=file >
$SRC_DIR
/include/migraphx/{}"
function
api
{
$PYTHON
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
|
$CLANG_FORMAT
-style
=
file
>
$2
}
api
$DIR
/api/migraphx.h
$SRC_DIR
/api/include/migraphx/migraphx.h
echo
"Finished generating header migraphx.h"
api
$DIR
/api/api.cpp
$SRC_DIR
/api/api.cpp
echo
"Finished generating source api.cpp "
import
api
,
argparse
,
os
,
runpy
,
subprocess
,
sys
,
te
from
pathlib
import
Path
clang_format_path
=
Path
(
'clang-format.exe'
if
os
.
name
==
'nt'
else
'/opt/rocm/llvm/bin/clang-format'
)
work_dir
=
Path
().
cwd
()
src_dir
=
(
work_dir
/
'../src'
).
absolute
()
migraphx_py_path
=
src_dir
/
'api/migraphx.py'
def
clang_format
(
buffer
,
**
kwargs
):
return
subprocess
.
run
(
f
'
{
clang_format_path
}
-style=file'
,
capture_output
=
True
,
shell
=
True
,
check
=
True
,
input
=
buffer
.
encode
(
'utf-8'
),
cwd
=
work_dir
,
**
kwargs
).
stdout
.
decode
(
'utf-8'
)
def
api_generate
(
input_path
:
Path
,
output_path
:
Path
):
with
open
(
output_path
,
'w'
)
as
f
:
f
.
write
(
clang_format
(
api
.
run
(
input_path
)))
def
te_generate
(
input_path
:
Path
,
output_path
:
Path
):
with
open
(
output_path
,
'w'
)
as
f
:
f
.
write
(
clang_format
(
te
.
run
(
input_path
)))
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'-f'
,
'--clang-format'
,
type
=
Path
)
args
=
parser
.
parse_args
()
global
clang_format_path
if
args
.
clang_format
:
clang_format_path
=
args
.
clang_format
if
not
clang_format_path
.
is_file
():
print
(
f
"
{
clang_format_path
}
: invalid path or not installed"
,
file
=
sys
.
stderr
)
return
try
:
files
=
Path
(
'include'
).
absolute
().
iterdir
()
for
f
in
[
f
for
f
in
files
if
f
.
is_file
()]:
te_generate
(
f
,
src_dir
/
f
'include/migraphx/
{
f
.
name
}
'
)
runpy
.
run_path
(
str
(
migraphx_py_path
))
api_generate
(
work_dir
/
'api/migraphx.h'
,
src_dir
/
'api/include/migraphx/migraphx.h'
)
print
(
'Finished generating header migraphx.h'
)
api_generate
(
work_dir
/
'api/api.cpp'
,
src_dir
/
'api/api.cpp'
)
print
(
'Finished generating source api.cpp'
)
except
subprocess
.
CalledProcessError
as
ex
:
if
ex
.
stdout
:
print
(
ex
.
stdout
.
decode
(
'utf-8'
))
if
ex
.
stderr
:
print
(
ex
.
stdout
.
decode
(
'utf-8'
))
print
(
f
"Command '
{
ex
.
cmd
}
' returned
{
ex
.
returncode
}
"
)
raise
if
__name__
==
"__main__"
:
main
()
tools/te.py
View file @
e1ebad1d
...
...
@@ -431,6 +431,9 @@ def template_eval(template, **kwargs):
return
template
f
=
open
(
sys
.
argv
[
1
]).
read
()
r
=
template_eval
(
f
)
sys
.
stdout
.
write
(
r
)
def
run
(
p
):
return
template_eval
(
open
(
p
).
read
())
if
__name__
==
'__main__'
:
sys
.
stdout
.
write
(
run
(
sys
.
argv
[
1
]))
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment