Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
01cf30d9
Commit
01cf30d9
authored
Oct 09, 2023
by
Artur Wojcik
Browse files
incorporate review feedback
parent
14e20a73
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
140 additions
and
153 deletions
+140
-153
src/api/api.cpp
src/api/api.cpp
+2
-2
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+6
-0
src/include/migraphx/op/random_uniform.hpp
src/include/migraphx/op/random_uniform.hpp
+20
-18
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+96
-107
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+0
-1
tools/api.py
tools/api.py
+6
-14
tools/api/api.cpp
tools/api/api.cpp
+2
-2
tools/generate.py
tools/generate.py
+6
-7
tools/te.py
tools/te.py
+2
-2
No files found.
src/api/api.cpp
View file @
01cf30d9
...
...
@@ -42,6 +42,8 @@
#include <algorithm>
#include <cstdarg>
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
...
...
@@ -51,8 +53,6 @@ extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(
}
#endif
namespace
migraphx
{
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
...
...
src/dynamic_loader.cpp
View file @
01cf30d9
...
...
@@ -105,6 +105,11 @@ struct dynamic_loader_impl
}
}
dynamic_loader_impl
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
&
operator
=
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
(
dynamic_loader_impl
&&
)
=
default
;
~
dynamic_loader_impl
()
{
if
(
handle
!=
nullptr
)
...
...
@@ -112,6 +117,7 @@ struct dynamic_loader_impl
FreeLibrary
(
handle
);
}
}
static
std
::
shared_ptr
<
dynamic_loader_impl
>
from_buffer
(
const
char
*
image
,
std
::
size_t
size
)
{
auto
t
=
tmp_dir
{
"migx-dynload"
};
...
...
src/include/migraphx/op/random_uniform.hpp
View file @
01cf30d9
...
...
@@ -75,26 +75,28 @@ struct random_uniform
result
.
visit
([
&
](
auto
output
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
#ifdef _MSC_VER
// According to the C++ specification, the effect is undefined if the result type
// for the generator is not one of short, int, long, long long, unsigned short,
// unsigned int, unsigned long, or unsigned long long. See
// https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.
if
constexpr
(
std
::
is_same_v
<
type
,
unsigned
char
>
||
std
::
is_same_v
<
type
,
signed
char
>
)
if
constexpr
(
std
::
is_integral
<
type
>
{})
{
std
::
uniform_int_distribution
<
int
>
dis
{
std
::
numeric_limits
<
type
>::
min
(),
std
::
numeric_limits
<
type
>::
max
()};
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
}
else
#ifdef _MSC_VER
// According to the C++ specification, the effect is undefined if the result type
// for the generator is not one of short, int, long, long long, unsigned short,
// unsigned int, unsigned long, or unsigned long long. See
// https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.
if
constexpr
(
sizeof
(
type
)
==
1
)
{
std
::
uniform_int_distribution
<
int
>
dis
{
std
::
numeric_limits
<
type
>::
min
(),
std
::
numeric_limits
<
type
>::
max
()};
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
}
else
#endif
if
constexpr
(
std
::
is_integral
<
type
>
{})
{
// default range for all integer types is
// (0, std::uniform_int_distribution<type>::max()).
// Todo: enable different ranges
std
::
uniform_int_distribution
<
type
>
dis
;
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
{
// default range for all integer types is
// (0, std::uniform_int_distribution<type>::max()).
// Todo: enable different ranges
std
::
uniform_int_distribution
<
type
>
dis
;
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
})
;
}
}
else
{
...
...
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
View file @
01cf30d9
...
...
@@ -91,118 +91,28 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};
template
<
class
F
>
struct
execute_wrapper
{
F
f
;
argument
operator
()(
context
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
f
(
args
);
}
};
template
<
class
F
>
execute_wrapper
<
F
>
make_execute_wrapper
(
F
f
)
{
return
{
std
::
move
(
f
)};
}
template
<
class
Derived
,
class
Primitive
>
struct
dnnl_op
:
auto_register_op
<
Derived
>
{
std
::
vector
<
post_op
>
post_ops
;
std
::
function
<
argument
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
>
execute
;
class
executable
{
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>
md
;
Primitive
prim
;
std
::
vector
<
int
>
arg_lookup
;
#ifdef _DEBUG
const
dnnl_op
&
self
;
const
Derived
&
derived
;
std
::
string
name
;
dnnl
::
primitive_attr
prim_attr
;
const
std
::
vector
<
shape
>&
inputs
;
const
shape
&
output_shape
;
#endif
public:
// clang-format off
executable
(
const
dnnl_op
&
op
,
const
shape
&
out_shape
,
const
std
::
vector
<
shape
>&
in_shapes
)
:
md
{
op
.
to_memory_desc
(
out_shape
,
in_shapes
)},
prim
{
op
.
get_primitive
(
md
)},
arg_lookup
{
op
.
create_arg_map
(
in_shapes
.
size
())}
#ifdef _DEBUG
,
self
{
op
},
derived
{
static_cast
<
const
Derived
&>
(
op
)},
name
{
derived
.
name
()},
prim_attr
{
op
.
get_primitive_attr
(
md
)},
inputs
{
in_shapes
},
output_shape
{
out_shape
}
#endif
// clang-format on
{
}
argument
operator
()(
context
&
,
const
std
::
vector
<
argument
>&
args
)
{
#ifdef _DEBUG
// Check that the memory descriptors have not changed
auto
debug_args
=
args
;
debug_args
.
pop_back
();
auto
debug_md
=
self
.
to_memory_desc
(
output_shape
,
to_shapes
(
debug_args
));
for
(
auto
&&
p
:
debug_md
)
{
if
(
md
.
count
(
p
.
first
)
==
0
)
MIGRAPHX_THROW
(
name
+
": Missing memory descriptor for: "
+
std
::
to_string
(
p
.
first
));
if
(
p
.
second
==
md
.
at
(
p
.
first
))
continue
;
MIGRAPHX_THROW
(
name
+
": Memory descriptor has changed for: "
+
std
::
to_string
(
p
.
first
));
}
// Check post_ops args are correct
auto
pos
=
prim_attr
.
get_post_ops
();
auto
prim_input_size
=
inputs
.
size
()
-
self
.
get_extra_post_op_args
();
int
j
=
0
;
for
(
int
i
=
0
;
i
<
pos
.
len
();
i
++
)
{
auto
arg
=
j
+
prim_input_size
;
auto
kind
=
pos
.
kind
(
i
);
std
::
string
mesg
=
"Post op "
+
std
::
to_string
(
i
)
+
"@"
+
std
::
to_string
(
arg
)
+
": "
;
try
{
dnnl
::
algorithm
algo
;
dnnl
::
memory
::
desc
mdesc
;
float
scale
=
0
;
float
alpha
=
0
;
float
beta
=
0
;
if
(
kind
==
dnnl
::
primitive
::
kind
::
binary
)
{
pos
.
get_params_binary
(
i
,
algo
,
mdesc
);
if
(
mdesc
!=
md
.
at
(
arg_lookup
.
at
(
arg
)))
MIGRAPHX_THROW
(
mesg
+
"Memory descriptor doesn't match for binary post op"
);
j
++
;
}
else
if
(
kind
==
dnnl
::
primitive
::
kind
::
eltwise
)
{
pos
.
get_params_eltwise
(
i
,
scale
,
algo
,
alpha
,
beta
);
}
else
if
(
kind
==
dnnl
::
primitive
::
kind
::
sum
)
{
pos
.
get_params_sum
(
i
,
scale
);
algo
=
dnnl
::
algorithm
::
binary_add
;
}
else
{
MIGRAPHX_THROW
(
"Unknown kind"
);
}
if
(
to_dnnl_algo
(
self
.
post_ops
[
i
].
algo
)
!=
algo
)
MIGRAPHX_THROW
(
mesg
+
"Algorithm doesn't match for post op "
+
self
.
post_ops
[
i
].
algo
+
" != "
+
to_string
(
algo
));
}
catch
(
const
dnnl
::
error
&
e
)
{
MIGRAPHX_THROW
(
mesg
+
"Failed to get post ops argument "
+
": "
+
e
.
what
());
}
}
#endif
std
::
unordered_map
<
int
,
dnnl
::
memory
>
m
;
m
[
MIGRAPHX_DNNL_PREFIX
(
ARG_DST
)]
=
to_dnnl_memory
(
md
.
at
(
MIGRAPHX_DNNL_PREFIX
(
ARG_DST
)),
args
.
back
());
for
(
int
i
=
0
;
i
<
args
.
size
()
-
1
;
i
++
)
m
[
arg_lookup
[
i
]]
=
to_dnnl_memory
(
md
.
at
(
arg_lookup
[
i
]),
args
[
i
]);
prim
.
execute
(
get_dnnl_context
().
stream
,
m
);
return
args
.
back
();
}
};
template
<
class
Self
,
class
F
>
static
auto
reflect_base
(
Self
&
self
,
F
f
)
{
...
...
@@ -406,7 +316,86 @@ struct dnnl_op : auto_register_op<Derived>
{
// Compensate for allocation
inputs
.
pop_back
();
execute
=
executable
{
*
this
,
output_shape
,
inputs
};
const
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
name
=
self
.
name
();
auto
md
=
to_memory_desc
(
output_shape
,
inputs
);
auto
prim
=
get_primitive
(
md
);
auto
arg_lookup
=
create_arg_map
(
inputs
.
size
());
#ifndef NDEBUG
auto
prim_attr
=
get_primitive_attr
(
md
);
#endif
execute
=
make_execute_wrapper
([
=
](
const
std
::
vector
<
argument
>&
args
)
{
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto
debug_args
=
args
;
debug_args
.
pop_back
();
auto
debug_md
=
to_memory_desc
(
output_shape
,
to_shapes
(
debug_args
));
for
(
auto
&&
p
:
debug_md
)
{
if
(
md
.
count
(
p
.
first
)
==
0
)
MIGRAPHX_THROW
(
name
+
": Missing memory descriptor for: "
+
std
::
to_string
(
p
.
first
));
if
(
p
.
second
==
md
.
at
(
p
.
first
))
continue
;
MIGRAPHX_THROW
(
name
+
": Memory descriptor has changed for: "
+
std
::
to_string
(
p
.
first
));
}
// Check post_ops args are correct
auto
pos
=
prim_attr
.
get_post_ops
();
auto
prim_input_size
=
inputs
.
size
()
-
this
->
get_extra_post_op_args
();
int
j
=
0
;
for
(
int
i
=
0
;
i
<
pos
.
len
();
i
++
)
{
auto
arg
=
j
+
prim_input_size
;
auto
kind
=
pos
.
kind
(
i
);
std
::
string
mesg
=
"Post op "
+
std
::
to_string
(
i
)
+
"@"
+
std
::
to_string
(
arg
)
+
": "
;
try
{
dnnl
::
algorithm
algo
;
dnnl
::
memory
::
desc
mdesc
;
float
scale
=
0
;
float
alpha
=
0
;
float
beta
=
0
;
if
(
kind
==
dnnl
::
primitive
::
kind
::
binary
)
{
pos
.
get_params_binary
(
i
,
algo
,
mdesc
);
if
(
mdesc
!=
md
.
at
(
arg_lookup
.
at
(
arg
)))
MIGRAPHX_THROW
(
mesg
+
"Memory descriptor doesn't match for binary post op"
);
j
++
;
}
else
if
(
kind
==
dnnl
::
primitive
::
kind
::
eltwise
)
{
pos
.
get_params_eltwise
(
i
,
scale
,
algo
,
alpha
,
beta
);
}
else
if
(
kind
==
dnnl
::
primitive
::
kind
::
sum
)
{
pos
.
get_params_sum
(
i
,
scale
);
algo
=
dnnl
::
algorithm
::
binary_add
;
}
else
{
MIGRAPHX_THROW
(
"Unknown kind"
);
}
if
(
to_dnnl_algo
(
post_ops
[
i
].
algo
)
!=
algo
)
MIGRAPHX_THROW
(
mesg
+
"Algorithm doesn't match for post op "
+
post_ops
[
i
].
algo
+
" != "
+
to_string
(
algo
));
}
catch
(
const
dnnl
::
error
&
e
)
{
MIGRAPHX_THROW
(
mesg
+
"Failed to get post ops argument "
+
": "
+
e
.
what
());
}
}
#endif
std
::
unordered_map
<
int
,
dnnl
::
memory
>
m
;
m
[
MIGRAPHX_DNNL_PREFIX
(
ARG_DST
)]
=
to_dnnl_memory
(
md
.
at
(
MIGRAPHX_DNNL_PREFIX
(
ARG_DST
)),
args
.
back
());
for
(
int
i
=
0
;
i
<
args
.
size
()
-
1
;
i
++
)
m
[
arg_lookup
[
i
]]
=
to_dnnl_memory
(
md
.
at
(
arg_lookup
[
i
]),
args
[
i
]);
prim
.
execute
(
get_dnnl_context
().
stream
,
m
);
return
args
.
back
();
});
}
std
::
vector
<
shape
>
trim_post_op_inputs
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
...
...
src/targets/gpu/compile_hip.cpp
View file @
01cf30d9
...
...
@@ -260,7 +260,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if
(
fs
::
exists
(
driver
))
#endif
{
value
v
;
v
[
"srcs"
]
=
to_value
(
hsrcs
);
v
[
"params"
]
=
to_value
(
params
);
...
...
tools/api.py
View file @
01cf30d9
...
...
@@ -679,10 +679,6 @@ def add_function(name: str, *args, **kwargs) -> Function:
return
f
def
register_functions
(
path
:
Union
[
Path
,
str
])
->
None
:
runpy
.
run_path
(
path
if
isinstance
(
path
,
str
)
else
str
(
path
))
def
once
(
f
:
Callable
)
->
Any
:
@
wraps
(
f
)
def
decorated
(
*
args
,
**
kwargs
):
...
...
@@ -1286,21 +1282,17 @@ def template_eval(template, **kwargs):
return
template
def
invoke
(
path
:
Union
[
Path
,
str
])
->
str
:
def
run
(
path
:
Union
[
Path
,
str
])
->
str
:
return
template_eval
(
open
(
path
).
read
())
def
run
(
args
:
List
[
str
])
->
None
:
register_functions
(
args
[
0
])
if
len
(
args
)
>
1
:
r
=
invoke
(
args
[
1
])
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
runpy
.
run_path
(
sys
.
argv
[
1
])
if
len
(
sys
.
argv
)
>
2
:
r
=
run
(
sys
.
argv
[
2
])
sys
.
stdout
.
write
(
r
)
else
:
sys
.
stdout
.
write
(
generate_c_header
())
sys
.
stdout
.
write
(
generate_c_api_body
())
# sys.stdout.write(generate_cpp_header())
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
run
(
sys
.
argv
[
1
:])
tools/api/api.cpp
View file @
01cf30d9
...
...
@@ -42,6 +42,8 @@
#include <algorithm>
#include <cstdarg>
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
...
...
@@ -51,8 +53,6 @@ extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(
}
#endif
namespace
migraphx
{
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
...
...
tools/generate.py
View file @
01cf30d9
...
...
@@ -21,7 +21,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import
os
,
sys
,
argparse
,
subprocess
,
te
,
api
import
api
,
argparse
,
os
,
runpy
,
subprocess
,
sys
,
te
from
pathlib
import
Path
clang_format_path
=
Path
(
'clang-format.exe'
if
os
.
name
==
...
...
@@ -43,12 +43,12 @@ def clang_format(buffer, **kwargs):
def
api_generate
(
input_path
:
Path
,
output_path
:
Path
):
with
open
(
output_path
,
'w'
)
as
f
:
f
.
write
(
clang_format
(
api
.
invoke
(
input_path
)))
f
.
write
(
clang_format
(
api
.
run
(
input_path
)))
def
te_generate
(
input_path
:
Path
,
output_path
:
Path
):
with
open
(
output_path
,
'w'
)
as
f
:
f
.
write
(
clang_format
(
te
.
invoke
(
input_path
)))
f
.
write
(
clang_format
(
te
.
run
(
input_path
)))
def
main
():
...
...
@@ -66,11 +66,10 @@ def main():
return
try
:
for
f
in
[
f
for
f
in
Path
(
'include'
).
absolute
().
iterdir
()
if
f
.
is_file
()
]:
files
=
Path
(
'include'
).
absolute
().
iterdir
()
for
f
in
[
f
for
f
in
files
if
f
.
is_file
()]:
te_generate
(
f
,
src_dir
/
f
'include/migraphx/
{
f
.
name
}
'
)
api
.
register_functions
(
str
(
migraphx_py_path
))
runpy
.
run_path
(
str
(
migraphx_py_path
))
api_generate
(
work_dir
/
'api/migraphx.h'
,
src_dir
/
'api/include/migraphx/migraphx.h'
)
print
(
'Finished generating header migraphx.h'
)
...
...
tools/te.py
View file @
01cf30d9
...
...
@@ -431,9 +431,9 @@ def template_eval(template, **kwargs):
return
template
def
invoke
(
p
):
def
run
(
p
):
return
template_eval
(
open
(
p
).
read
())
if
__name__
==
'__main__'
:
sys
.
stdout
.
write
(
invoke
(
sys
.
argv
[
1
]))
sys
.
stdout
.
write
(
run
(
sys
.
argv
[
1
]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment