Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1530ec24
Unverified
Commit
1530ec24
authored
Sep 18, 2023
by
Ted Themistokleous
Committed by
GitHub
Sep 18, 2023
Browse files
Merge branch 'develop' into add_parity_check_ci
parents
5c98fcb0
c2e01b10
Changes
305
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
774 additions
and
335 deletions
+774
-335
src/api/CMakeLists.txt
src/api/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+3
-3
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+298
-266
src/api/migraphx.py
src/api/migraphx.py
+3
-2
src/common_dims.cpp
src/common_dims.cpp
+156
-0
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+11
-6
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+14
-4
src/driver/main.cpp
src/driver/main.cpp
+35
-13
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+13
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+20
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+54
-1
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+2
-2
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+38
-0
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/builtin.hpp
src/include/migraphx/builtin.hpp
+11
-1
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+58
-27
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+49
-0
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+2
-2
src/include/migraphx/convolution.hpp
src/include/migraphx/convolution.hpp
+2
-2
No files found.
src/api/CMakeLists.txt
View file @
1530ec24
...
@@ -26,6 +26,7 @@ add_library(migraphx_c
...
@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp
api.cpp
)
)
set_target_properties
(
migraphx_c PROPERTIES EXPORT_NAME c
)
set_target_properties
(
migraphx_c PROPERTIES EXPORT_NAME c
)
migraphx_generate_export_header
(
migraphx_c DIRECTORY migraphx/api
)
# migraphx_c is stable API interface library. SO version of this should be
# migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken.
# bumped when binary compatibility is broken.
...
...
src/api/api.cpp
View file @
1530ec24
...
@@ -44,7 +44,7 @@ namespace migraphx {
...
@@ -44,7 +44,7 @@ namespace migraphx {
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
{
disable_exception_catch
=
b
;
disable_exception_catch
=
b
;
}
}
...
@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
...
@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
extern
"C"
migraphx_status
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
const_migraphx_dynamic_dimension_t
*
ptr
,
const
const_migraphx_dynamic_dimension_t
*
ptr
,
size_t
size
)
size_t
size
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
...
@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
}
}
extern
"C"
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
extern
"C"
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
const
const_migraphx_instruction_t
*
ptr
,
size_t
size
)
size_t
size
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
...
src/api/include/migraphx/migraphx.h
View file @
1530ec24
This diff is collapsed.
Click to expand it.
src/api/migraphx.py
View file @
1530ec24
...
@@ -79,7 +79,8 @@ def dynamic_dimension(h):
...
@@ -79,7 +79,8 @@ def dynamic_dimension(h):
def
dynamic_dimensions
(
h
):
def
dynamic_dimensions
(
h
):
h
.
constructor
(
h
.
constructor
(
'create'
,
'create'
,
api
.
params
(
ptr
=
'const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'get'
,
h
.
method
(
'get'
,
...
@@ -215,7 +216,7 @@ def instruction(h):
...
@@ -215,7 +216,7 @@ def instruction(h):
def
instructions
(
h
):
def
instructions
(
h
):
h
.
constructor
(
h
.
constructor
(
'create'
,
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const
const
_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
...
...
src/common_dims.cpp
0 → 100644
View file @
1530ec24
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>
dim
;
});
if
(
x
<
dim
)
return
start
;
return
it
;
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
std
::
accumulate
(
r
.
begin
(),
r
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
get
()
const
{
return
*
it
/
rem
;
}
bool
is_end
()
const
{
return
it
==
dims
->
end
();
}
void
next
(
std
::
size_t
i
=
1
)
{
it
+=
i
;
}
auto
dims_for
(
std
::
size_t
d
)
const
{
auto
dim_end
=
compute_end_dim
(
it
,
dims
->
end
(),
d
);
return
range
(
it
,
dim_end
);
}
void
add_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
axes_map
->
push_back
(
std
::
move
(
axes
));
}
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
}
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
start
);
return
axes
;
}
};
static
bool
compute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
if
(
naxes
==
0
)
return
false
;
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
return
false
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
==
1
?
naxes
:
naxes
+
1
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
cd_dims
.
insert
(
cd_dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
state1
.
rem
!=
1
)
cd_dims
.
push_back
(
state1
.
rem
);
state1
.
next
(
distance
(
dims
));
state2
.
next
();
return
true
;
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
>
0
);
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
common_dim_state
state1
{
dims1
,
cd
.
axes_map1
};
common_dim_state
state2
{
dims2
,
cd
.
axes_map2
};
while
(
not
state1
.
is_end
()
and
not
state2
.
is_end
())
{
auto
d1
=
state1
.
get
();
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
// if(d1 > d2)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
}
}
assert
(
elements
(
dims1
)
==
elements
(
cd
.
dims
));
return
cd
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/driver/CMakeLists.txt
View file @
1530ec24
...
@@ -32,18 +32,23 @@ add_executable(driver
...
@@ -32,18 +32,23 @@ add_executable(driver
marker_roctx.cpp
marker_roctx.cpp
)
)
set_target_properties
(
driver PROPERTIES OUTPUT_NAME migraphx-driver
)
set_target_properties
(
driver PROPERTIES OUTPUT_NAME migraphx-driver
)
# Copy driver for backwards compatibility
if
(
NOT WIN32
)
add_custom_command
(
# Copy driver for backwards compatibility (Linux only)
TARGET driver
add_custom_command
(
TARGET driver
POST_BUILD COMMAND
${
CMAKE_COMMAND
}
-E copy
POST_BUILD COMMAND
${
CMAKE_COMMAND
}
-E copy
$<TARGET_FILE:driver>
$<TARGET_FILE:driver>
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
BYPRODUCTS
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
BYPRODUCTS
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
)
set_directory_properties
(
PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
set_directory_properties
(
PROPERTIES ADDITIONAL_CLEAN_FILES
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
endif
()
rocm_clang_tidy_check
(
driver
)
rocm_clang_tidy_check
(
driver
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
rocm_install_targets
(
rocm_install_targets
(
TARGETS driver
TARGETS driver
...
...
src/driver/argument_parser.hpp
View file @
1530ec24
...
@@ -338,11 +338,22 @@ struct argument_parser
...
@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
{
{
return
validate
([](
auto
&
,
auto
&
,
auto
&
params
)
{
return
validate
([](
auto
&
,
auto
&
,
const
auto
&
params
)
{
if
(
params
.
empty
())
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"No argument passed."
);
throw
std
::
runtime_error
(
"No argument passed."
);
if
(
not
fs
::
exists
(
params
.
back
()))
if
(
not
fs
::
exists
(
params
.
back
()))
throw
std
::
runtime_error
(
"Path does not exists: "
+
params
.
back
());
throw
std
::
runtime_error
(
"Path does not exist: "
+
params
.
back
());
});
}
MIGRAPHX_DRIVER_STATIC
auto
matches
(
const
std
::
unordered_set
<
std
::
string
>&
names
)
{
return
validate
([
=
](
auto
&
,
auto
&
,
const
auto
&
params
)
{
auto
invalid_param
=
std
::
find_if
(
params
.
begin
(),
params
.
end
(),
[
&
](
const
auto
&
p
)
{
return
names
.
count
(
p
)
==
0
;
});
if
(
invalid_param
!=
params
.
end
())
throw
std
::
runtime_error
(
"Invalid argument: "
+
*
invalid_param
+
". Valid arguments are {"
+
to_string_range
(
names
)
+
"}"
);
});
});
}
}
...
@@ -570,8 +581,7 @@ struct argument_parser
...
@@ -570,8 +581,7 @@ struct argument_parser
continue
;
continue
;
if
(
flag
[
0
]
!=
'-'
)
if
(
flag
[
0
]
!=
'-'
)
continue
;
continue
;
auto
d
=
std
::
ptrdiff_t
d
=
levenshtein_distance
(
flag
,
input
);
levenshtein_distance
(
flag
.
begin
(),
flag
.
end
(),
input
.
begin
(),
input
.
end
());
if
(
d
<
result
.
distance
)
if
(
d
<
result
.
distance
)
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
}
}
...
...
src/driver/main.cpp
View file @
1530ec24
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
...
@@ -81,6 +82,7 @@ struct loader
...
@@ -81,6 +82,7 @@ struct loader
{
"--model"
},
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
matches
({
"resnet50"
,
"inceptionv3"
,
"alexnet"
}),
ap
.
group
(
"input"
));
ap
.
group
(
"input"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
...
@@ -241,6 +243,20 @@ struct loader
...
@@ -241,6 +243,20 @@ struct loader
return
options
;
return
options
;
}
}
static
std
::
string
get_file_type
(
const
std
::
string
&
file
)
{
if
(
ends_with
(
file
,
".onnx"
))
return
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
return
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
return
"json"
;
else
if
(
ends_with
(
file
,
".py"
))
return
"py"
;
else
return
"migraphx"
;
}
program
load
()
program
load
()
{
{
program
p
;
program
p
;
...
@@ -248,14 +264,7 @@ struct loader
...
@@ -248,14 +264,7 @@ struct loader
{
{
if
(
file_type
.
empty
())
if
(
file_type
.
empty
())
{
{
if
(
ends_with
(
file
,
".onnx"
))
file_type
=
get_file_type
(
file
);
file_type
=
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
file_type
=
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
file_type
=
"json"
;
else
file_type
=
"migraphx"
;
}
}
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
if
(
file_type
==
"onnx"
)
if
(
file_type
==
"onnx"
)
...
@@ -272,6 +281,10 @@ struct loader
...
@@ -272,6 +281,10 @@ struct loader
options
.
format
=
"json"
;
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
p
=
migraphx
::
load
(
file
,
options
);
}
}
else
if
(
file_type
==
"py"
)
{
p
=
migraphx
::
load_py
(
file
);
}
else
if
(
file_type
==
"migraphx"
)
else
if
(
file_type
==
"migraphx"
)
{
{
p
=
migraphx
::
load
(
file
);
p
=
migraphx
::
load
(
file
);
...
@@ -462,13 +475,15 @@ struct compiler
...
@@ -462,13 +475,15 @@ struct compiler
{
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
std
::
cout
"passing "
<<
"[WARNING]: MIGraphX program was likely compiled with offload_copy "
"`--enable-offload-copy` if program run fails.
\n
"
;
"set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
}
}
else
if
(
co
.
offload_copy
)
else
if
(
co
.
offload_copy
)
{
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
std
::
cout
<<
"
[WARNING]:
MIGraphX program was likely compiled without "
"offload_copy set, Try "
"offload_copy set, Try "
"removing "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"`--enable-offload-copy` flag if passed to driver, if program run "
...
@@ -757,7 +772,7 @@ struct main_command
...
@@ -757,7 +772,7 @@ struct main_command
{
{
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
<<
"' is not a valid command."
<<
std
::
endl
;
<<
"' is not a valid command."
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
)
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
);
}
}
else
else
{
{
...
@@ -789,6 +804,13 @@ int main(int argc, const char* argv[])
...
@@ -789,6 +804,13 @@ int main(int argc, const char* argv[])
auto
&&
m
=
get_commands
();
auto
&&
m
=
get_commands
();
auto
cmd
=
args
.
front
();
auto
cmd
=
args
.
front
();
if
(
cmd
==
"ort-sha"
)
{
std
::
cout
<<
MIGRAPHX_ORT_SHA1
<<
std
::
endl
;
return
0
;
}
if
(
m
.
count
(
cmd
)
>
0
)
if
(
m
.
count
(
cmd
)
>
0
)
{
{
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
...
...
src/dynamic_loader.cpp
View file @
1530ec24
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
#endif
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
LAZY
),
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
GLOBAL
|
RTLD_NOW
),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
temp
(
std
::
move
(
t
))
temp
(
std
::
move
(
t
))
{
{
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return
p
;
return
p
;
}
}
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
try
{
return
dynamic_loader
{
p
};
}
catch
(
const
std
::
exception
&
)
{
return
nullopt
;
}
}
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
{
{
}
}
...
...
src/eliminate_contiguous.cpp
View file @
1530ec24
...
@@ -35,6 +35,8 @@
...
@@ -35,6 +35,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
)
static
bool
try_compute_shape
(
instruction_ref
ins
,
static
bool
try_compute_shape
(
instruction_ref
ins
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
std
::
vector
<
module_ref
>&
mods
)
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
output
->
module_inputs
()
))
{
{
return
false
;
return
false
;
}
}
}
}
}
}
catch
(
const
std
::
exception
&
e
)
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
return
false
;
}
catch
(...)
catch
(...)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Unknown exception"
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{
{
if
(
arg
->
name
()
!=
op_name
)
if
(
arg
->
name
()
!=
op_name
)
continue
;
continue
;
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"eliminate_contiguous: "
;
m
.
debug_print
(
ins
);
}
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
...
...
src/fuse_pointwise.cpp
View file @
1530ec24
...
@@ -24,11 +24,14 @@
...
@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
...
@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins)
...
@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
s
.
elements
()
!=
1
&&
not
(
s
.
scalar
()))
if
(
s
.
elements
()
!=
1
and
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
}
}
return
changed
;
return
changed
;
}
}
namespace
{
struct
find_pointwise_reshape_pointwise
{
auto
matcher
()
const
{
auto
reshape
=
match
::
name
(
"reshape"
,
"squeeze"
,
"unsqueeze"
,
"flatten"
)(
match
::
used_once
());
auto
skip_contiguous
=
[](
auto
...
ms
)
{
return
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
)(
match
::
used_once
()))(
ms
...));
};
auto
pointwise
=
match
::
name
(
"pointwise"
)(
match
::
used_once
());
auto
reshape_pointwise
=
reshape
(
skip_contiguous
(
pointwise
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
cd
=
common_dims
::
compute
(
ins
->
get_shape
().
lens
(),
x_ins
->
get_shape
().
lens
());
if
(
cd
.
dims
.
empty
())
return
;
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
std
::
transform
(
x_inputs
.
begin
(),
x_inputs
.
end
(),
x_inputs
.
begin
(),
reshape_input
(
x_ins
));
auto
new_x_ins
=
m
.
insert_instruction
(
x_ins
,
x_ins
->
get_operator
(),
x_inputs
,
x_ins
->
module_inputs
());
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
==
reshape_ins
)
return
new_x_ins
;
return
reshape_input
(
ins
)(
input
);
});
auto
pw
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
ins
->
get_shape
().
lens
()}}),
pw
);
}
};
}
// namespace
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
mpm
.
get_module
(),
find_pointwise_reshape_pointwise
{});
mpm
.
run_pass
(
simplify_reshapes
{
1
});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/fuse_reduce.cpp
View file @
1530ec24
...
@@ -52,7 +52,7 @@ struct fused_reduce
...
@@ -52,7 +52,7 @@ struct fused_reduce
{
{
if
(
mods
.
size
()
!=
1
)
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
auto
*
sm
=
mods
.
front
();
const
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
MIGRAPHX_THROW
(
"Only one output supported"
);
auto
names
=
sm
->
get_parameter_names
();
auto
names
=
sm
->
get_parameter_names
();
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
}
}
static
std
::
vector
<
instruction_ref
>
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
find_inputs
(
const_
module_ref
sm
,
const
module
&
parent
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
...
...
src/include/migraphx/algorithm.hpp
View file @
1530ec24
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
}
}
inline
size_t
levenshtein_distance
(
const
std
::
string
&
s1
,
const
std
::
string
&
s2
)
{
const
size_t
l1
=
s1
.
length
();
const
size_t
l2
=
s2
.
length
();
if
(
l1
<
l2
)
levenshtein_distance
(
s2
,
s1
);
std
::
vector
<
size_t
>
d
(
l2
+
1
);
std
::
iota
(
d
.
begin
(),
d
.
end
(),
0
);
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
size_t
prev_cost
=
d
[
0
];
d
[
0
]
=
i
;
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
{
if
(
s1
[
i
-
1
]
==
s2
[
j
-
1
])
{
d
[
j
]
=
prev_cost
;
}
else
{
size_t
cost_insert_or_delete
=
std
::
min
(
d
[
j
-
1
],
d
[
j
]);
size_t
cost_substitute
=
prev_cost
;
prev_cost
=
d
[
j
];
d
[
j
]
=
std
::
min
(
cost_substitute
,
cost_insert_or_delete
)
+
1
;
}
}
}
return
d
[
l2
];
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/allocation_model.hpp
View file @
1530ec24
...
@@ -96,7 +96,7 @@ struct allocation_model
...
@@ -96,7 +96,7 @@ struct allocation_model
{
{
using
std
::
swap
;
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
}
...
@@ -267,7 +267,7 @@ struct allocation_model
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/builtin.hpp
View file @
1530ec24
...
@@ -90,7 +90,17 @@ struct param
...
@@ -90,7 +90,17 @@ struct param
struct
returns
struct
returns
{
{
std
::
string
name
()
const
{
return
"@return"
;
}
std
::
string
name
()
const
{
return
"@return"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
arg
)
const
{
if
(
arg
.
empty
())
return
{};
else
if
(
arg
.
size
()
==
1
)
return
arg
[
0
];
else
return
arg
;
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
MIGRAPHX_THROW
(
"builtin"
);
MIGRAPHX_THROW
(
"builtin"
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
1530ec24
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -34,29 +34,51 @@
...
@@ -34,29 +34,51 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
struct
check_shapes
{
{
const
shape
*
begin
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
const
shape
*
end
;
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
std
::
string
name
;
bool
dynamic_allowed
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
,
MIGRAPHX_REQUIRES
(
not
std
::
is_convertible
<
Op
,
std
::
string
>{})
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
n
),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
...
@@ -81,8 +103,6 @@ struct check_shapes
...
@@ -81,8 +103,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
0
;
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
return
end
-
begin
;
}
}
...
@@ -131,11 +151,9 @@ struct check_shapes
...
@@ -131,11 +151,9 @@ struct check_shapes
*/
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
}
return
*
this
;
return
*
this
;
...
@@ -148,11 +166,9 @@ struct check_shapes
...
@@ -148,11 +166,9 @@ struct check_shapes
*/
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -166,11 +182,9 @@ struct check_shapes
...
@@ -166,11 +182,9 @@ struct check_shapes
*/
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -220,6 +234,16 @@ struct check_shapes
...
@@ -220,6 +234,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same layout.
*/
const
check_shapes
&
same_layout
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
find_permutation
(
s
);
}))
MIGRAPHX_THROW
(
prefix
()
+
"Layouts do not match"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard.
* Check all shapes are standard.
*/
*/
...
@@ -230,6 +254,16 @@ struct check_shapes
...
@@ -230,6 +254,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard or scalar.
* Check all shapes are standard or scalar.
*/
*/
...
@@ -330,8 +364,6 @@ struct check_shapes
...
@@ -330,8 +364,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
}
...
@@ -341,8 +373,6 @@ struct check_shapes
...
@@ -341,8 +373,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
}
...
@@ -351,17 +381,13 @@ struct check_shapes
...
@@ -351,17 +381,13 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
false
;
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
{
if
(
i
>=
size
())
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
if
(
i
<
0
)
return
end
-
i
;
return
end
-
i
;
return
begin
+
i
;
return
begin
+
i
;
...
@@ -394,6 +420,11 @@ struct check_shapes
...
@@ -394,6 +420,11 @@ struct check_shapes
}
}
};
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
1530ec24
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct
MIGRAPHX_EXPORT
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/concat_opt.hpp
View file @
1530ec24
...
@@ -88,7 +88,7 @@ struct concat_optimization
...
@@ -88,7 +88,7 @@ struct concat_optimization
{
{
using
std
::
swap
;
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
}
...
@@ -233,7 +233,7 @@ struct concat_optimization
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/context.hpp
View file @
1530ec24
...
@@ -118,7 +118,7 @@ struct context
...
@@ -118,7 +118,7 @@ struct context
{
{
using
std
::
swap
;
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
}
...
@@ -373,7 +373,7 @@ struct context
...
@@ -373,7 +373,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/convolution.hpp
View file @
1530ec24
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
...
@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
shape
win_shape
{
output_shape
.
type
(),
win_size
};
shape
win_shape
{
output_shape
.
type
(),
win_size
};
double
acc
=
0.0
;
double
acc
=
0.0
;
shape_for_each
(
win_shape
,
[
&
](
auto
idx_win
)
{
shape_for_each
(
win_shape
,
[
&
](
const
auto
&
idx_win
)
{
auto
k
=
idx_win
[
0
];
auto
k
=
idx_win
[
0
];
const
auto
in_ch
=
group_id
*
wei_c
+
k
;
const
auto
in_ch
=
group_id
*
wei_c
+
k
;
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
...
...
Prev
1
2
3
4
5
6
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment