Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
32b83c9c
Commit
32b83c9c
authored
Sep 25, 2023
by
Khalique Ahmed
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into inner_bcast_fix
parents
92f5a6cd
434a06cf
Changes
291
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
533 additions
and
69 deletions
+533
-69
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+2
-2
src/api/migraphx.py
src/api/migraphx.py
+3
-2
src/common_dims.cpp
src/common_dims.cpp
+156
-0
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+3
-0
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+14
-4
src/driver/main.cpp
src/driver/main.cpp
+15
-5
src/driver/verify.cpp
src/driver/verify.cpp
+29
-4
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+20
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+54
-1
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+2
-2
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+38
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+58
-27
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+49
-0
src/include/migraphx/convolution.hpp
src/include/migraphx/convolution.hpp
+2
-2
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+12
-8
src/include/migraphx/normalize_attributes.hpp
src/include/migraphx/normalize_attributes.hpp
+32
-1
src/include/migraphx/op/allocate.hpp
src/include/migraphx/op/allocate.hpp
+33
-5
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+6
-2
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+4
-2
No files found.
src/api/include/migraphx/migraphx.h
View file @
32b83c9c
...
@@ -209,7 +209,7 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimensions_assign_to(
...
@@ -209,7 +209,7 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimensions_assign_to(
MIGRAPHX_C_EXPORT
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
const_migraphx_dynamic_dimension_t
*
ptr
,
const
const_migraphx_dynamic_dimension_t
*
ptr
,
size_t
size
);
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
...
@@ -377,7 +377,7 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_assign_to(
...
@@ -377,7 +377,7 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_assign_to(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_create
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
migraphx_instructions_t
*
instructions
,
const
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
...
...
src/api/migraphx.py
View file @
32b83c9c
...
@@ -79,7 +79,8 @@ def dynamic_dimension(h):
...
@@ -79,7 +79,8 @@ def dynamic_dimension(h):
def
dynamic_dimensions
(
h
):
def
dynamic_dimensions
(
h
):
h
.
constructor
(
h
.
constructor
(
'create'
,
'create'
,
api
.
params
(
ptr
=
'const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'get'
,
h
.
method
(
'get'
,
...
@@ -215,7 +216,7 @@ def instruction(h):
...
@@ -215,7 +216,7 @@ def instruction(h):
def
instructions
(
h
):
def
instructions
(
h
):
h
.
constructor
(
h
.
constructor
(
'create'
,
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const
const
_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
...
...
src/common_dims.cpp
0 → 100644
View file @
32b83c9c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>
dim
;
});
if
(
x
<
dim
)
return
start
;
return
it
;
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
std
::
accumulate
(
r
.
begin
(),
r
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
get
()
const
{
return
*
it
/
rem
;
}
bool
is_end
()
const
{
return
it
==
dims
->
end
();
}
void
next
(
std
::
size_t
i
=
1
)
{
it
+=
i
;
}
auto
dims_for
(
std
::
size_t
d
)
const
{
auto
dim_end
=
compute_end_dim
(
it
,
dims
->
end
(),
d
);
return
range
(
it
,
dim_end
);
}
void
add_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
axes_map
->
push_back
(
std
::
move
(
axes
));
}
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
}
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
start
);
return
axes
;
}
};
static
bool
compute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
if
(
naxes
==
0
)
return
false
;
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
return
false
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
==
1
?
naxes
:
naxes
+
1
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
cd_dims
.
insert
(
cd_dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
state1
.
rem
!=
1
)
cd_dims
.
push_back
(
state1
.
rem
);
state1
.
next
(
distance
(
dims
));
state2
.
next
();
return
true
;
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
>
0
);
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
common_dim_state
state1
{
dims1
,
cd
.
axes_map1
};
common_dim_state
state2
{
dims2
,
cd
.
axes_map2
};
while
(
not
state1
.
is_end
()
and
not
state2
.
is_end
())
{
auto
d1
=
state1
.
get
();
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
// if(d1 > d2)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
}
}
assert
(
elements
(
dims1
)
==
elements
(
cd
.
dims
));
return
cd
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/driver/CMakeLists.txt
View file @
32b83c9c
...
@@ -45,6 +45,9 @@ if(NOT WIN32)
...
@@ -45,6 +45,9 @@ if(NOT WIN32)
endif
()
endif
()
rocm_clang_tidy_check
(
driver
)
rocm_clang_tidy_check
(
driver
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
rocm_install_targets
(
rocm_install_targets
(
...
...
src/driver/argument_parser.hpp
View file @
32b83c9c
...
@@ -338,11 +338,22 @@ struct argument_parser
...
@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
{
{
return
validate
([](
auto
&
,
auto
&
,
auto
&
params
)
{
return
validate
([](
auto
&
,
auto
&
,
const
auto
&
params
)
{
if
(
params
.
empty
())
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"No argument passed."
);
throw
std
::
runtime_error
(
"No argument passed."
);
if
(
not
fs
::
exists
(
params
.
back
()))
if
(
not
fs
::
exists
(
params
.
back
()))
throw
std
::
runtime_error
(
"Path does not exists: "
+
params
.
back
());
throw
std
::
runtime_error
(
"Path does not exist: "
+
params
.
back
());
});
}
MIGRAPHX_DRIVER_STATIC
auto
matches
(
const
std
::
unordered_set
<
std
::
string
>&
names
)
{
return
validate
([
=
](
auto
&
,
auto
&
,
const
auto
&
params
)
{
auto
invalid_param
=
std
::
find_if
(
params
.
begin
(),
params
.
end
(),
[
&
](
const
auto
&
p
)
{
return
names
.
count
(
p
)
==
0
;
});
if
(
invalid_param
!=
params
.
end
())
throw
std
::
runtime_error
(
"Invalid argument: "
+
*
invalid_param
+
". Valid arguments are {"
+
to_string_range
(
names
)
+
"}"
);
});
});
}
}
...
@@ -570,8 +581,7 @@ struct argument_parser
...
@@ -570,8 +581,7 @@ struct argument_parser
continue
;
continue
;
if
(
flag
[
0
]
!=
'-'
)
if
(
flag
[
0
]
!=
'-'
)
continue
;
continue
;
auto
d
=
std
::
ptrdiff_t
d
=
levenshtein_distance
(
flag
,
input
);
levenshtein_distance
(
flag
.
begin
(),
flag
.
end
(),
input
.
begin
(),
input
.
end
());
if
(
d
<
result
.
distance
)
if
(
d
<
result
.
distance
)
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
}
}
...
...
src/driver/main.cpp
View file @
32b83c9c
...
@@ -82,6 +82,7 @@ struct loader
...
@@ -82,6 +82,7 @@ struct loader
{
"--model"
},
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
matches
({
"resnet50"
,
"inceptionv3"
,
"alexnet"
}),
ap
.
group
(
"input"
));
ap
.
group
(
"input"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
...
@@ -474,13 +475,15 @@ struct compiler
...
@@ -474,13 +475,15 @@ struct compiler
{
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
std
::
cout
<<
"[WARNING]: MIGraphX program was likely compiled with offload_copy "
"set, Try "
"passing "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
"`--enable-offload-copy` if program run fails.
\n
"
;
}
}
else
if
(
co
.
offload_copy
)
else
if
(
co
.
offload_copy
)
{
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
std
::
cout
<<
"
[WARNING]:
MIGraphX program was likely compiled without "
"offload_copy set, Try "
"offload_copy set, Try "
"removing "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"`--enable-offload-copy` flag if passed to driver, if program run "
...
@@ -769,7 +772,7 @@ struct main_command
...
@@ -769,7 +772,7 @@ struct main_command
{
{
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
<<
"' is not a valid command."
<<
std
::
endl
;
<<
"' is not a valid command."
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
)
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
);
}
}
else
else
{
{
...
@@ -801,6 +804,13 @@ int main(int argc, const char* argv[])
...
@@ -801,6 +804,13 @@ int main(int argc, const char* argv[])
auto
&&
m
=
get_commands
();
auto
&&
m
=
get_commands
();
auto
cmd
=
args
.
front
();
auto
cmd
=
args
.
front
();
if
(
cmd
==
"ort-sha"
)
{
std
::
cout
<<
MIGRAPHX_ORT_SHA1
<<
std
::
endl
;
return
0
;
}
if
(
m
.
count
(
cmd
)
>
0
)
if
(
m
.
count
(
cmd
)
>
0
)
{
{
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
...
...
src/driver/verify.cpp
View file @
32b83c9c
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
...
@@ -83,9 +84,19 @@ void verify_program(const std::string& name,
...
@@ -83,9 +84,19 @@ void verify_program(const std::string& name,
std
::
size_t
output_num
=
x
.
size
();
std
::
size_t
output_num
=
x
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
{
if
(
x
[
i
].
get_shape
().
type
()
!=
y
[
i
].
get_shape
().
type
()
or
x
[
i
].
get_shape
().
lens
()
!=
y
[
i
].
get_shape
().
lens
())
{
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"Shape mismatch {"
<<
x
[
i
].
get_shape
()
<<
"} != {"
<<
y
[
i
].
get_shape
()
<<
"}"
<<
std
::
endl
;
}
else
{
{
verify_args
(
name
,
x
[
i
],
y
[
i
],
tolerance
);
verify_args
(
name
,
x
[
i
],
y
[
i
],
tolerance
);
}
}
}
}
}
void
verify_instructions
(
const
program
&
prog
,
void
verify_instructions
(
const
program
&
prog
,
...
@@ -143,11 +154,19 @@ void verify_reduced(program p,
...
@@ -143,11 +154,19 @@ void verify_reduced(program p,
double
tolerance
)
double
tolerance
)
{
{
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
try
{
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cout
<<
"FAILED: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
}
}
void
verify_reduced_program
(
const
program
&
p
,
void
verify_reduced_program
(
const
program
&
p
,
...
@@ -160,8 +179,14 @@ void verify_reduced_program(const program& p,
...
@@ -160,8 +179,14 @@ void verify_reduced_program(const program& p,
const
auto
*
mm
=
p
.
get_main_module
();
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
n
;
i
++
)
{
{
auto
last
=
std
::
prev
(
mm
->
end
(),
i
+
1
);
if
(
contains
({
"@literal"
,
"@param"
},
last
->
name
()))
{
std
::
cout
<<
"Skip: "
<<
i
<<
std
::
endl
;
continue
;
}
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
}
}
}
}
...
...
src/eliminate_contiguous.cpp
View file @
32b83c9c
...
@@ -36,6 +36,8 @@
...
@@ -36,6 +36,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
)
static
bool
try_compute_shape
(
instruction_ref
ins
,
static
bool
try_compute_shape
(
instruction_ref
ins
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
std
::
vector
<
module_ref
>&
mods
)
...
@@ -79,14 +81,26 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -79,14 +81,26 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
output
->
module_inputs
()
))
{
{
return
false
;
return
false
;
}
}
}
}
}
}
catch
(
const
std
::
exception
&
e
)
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
return
false
;
}
catch
(...)
catch
(...)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Unknown exception"
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
...
@@ -128,6 +142,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
...
@@ -128,6 +142,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{
{
if
(
arg
->
name
()
!=
op_name
)
if
(
arg
->
name
()
!=
op_name
)
continue
;
continue
;
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"eliminate_contiguous: "
;
m
.
debug_print
(
ins
);
}
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
...
...
src/fuse_pointwise.cpp
View file @
32b83c9c
...
@@ -24,11 +24,14 @@
...
@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
...
@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins)
...
@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
s
.
elements
()
!=
1
&&
not
(
s
.
scalar
()))
if
(
s
.
elements
()
!=
1
and
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
}
}
return
changed
;
return
changed
;
}
}
namespace
{
struct
find_pointwise_reshape_pointwise
{
auto
matcher
()
const
{
auto
reshape
=
match
::
name
(
"reshape"
,
"squeeze"
,
"unsqueeze"
,
"flatten"
)(
match
::
used_once
());
auto
skip_contiguous
=
[](
auto
...
ms
)
{
return
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
)(
match
::
used_once
()))(
ms
...));
};
auto
pointwise
=
match
::
name
(
"pointwise"
)(
match
::
used_once
());
auto
reshape_pointwise
=
reshape
(
skip_contiguous
(
pointwise
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
cd
=
common_dims
::
compute
(
ins
->
get_shape
().
lens
(),
x_ins
->
get_shape
().
lens
());
if
(
cd
.
dims
.
empty
())
return
;
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
std
::
transform
(
x_inputs
.
begin
(),
x_inputs
.
end
(),
x_inputs
.
begin
(),
reshape_input
(
x_ins
));
auto
new_x_ins
=
m
.
insert_instruction
(
x_ins
,
x_ins
->
get_operator
(),
x_inputs
,
x_ins
->
module_inputs
());
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
==
reshape_ins
)
return
new_x_ins
;
return
reshape_input
(
ins
)(
input
);
});
auto
pw
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
ins
->
get_shape
().
lens
()}}),
pw
);
}
};
}
// namespace
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
mpm
.
get_module
(),
find_pointwise_reshape_pointwise
{});
mpm
.
run_pass
(
simplify_reshapes
{
1
});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/fuse_reduce.cpp
View file @
32b83c9c
...
@@ -52,7 +52,7 @@ struct fused_reduce
...
@@ -52,7 +52,7 @@ struct fused_reduce
{
{
if
(
mods
.
size
()
!=
1
)
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
auto
*
sm
=
mods
.
front
();
const
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
MIGRAPHX_THROW
(
"Only one output supported"
);
auto
names
=
sm
->
get_parameter_names
();
auto
names
=
sm
->
get_parameter_names
();
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
}
}
static
std
::
vector
<
instruction_ref
>
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
find_inputs
(
const_
module_ref
sm
,
const
module
&
parent
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
...
...
src/include/migraphx/algorithm.hpp
View file @
32b83c9c
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
}
}
inline
size_t
levenshtein_distance
(
const
std
::
string
&
s1
,
const
std
::
string
&
s2
)
{
const
size_t
l1
=
s1
.
length
();
const
size_t
l2
=
s2
.
length
();
if
(
l1
<
l2
)
levenshtein_distance
(
s2
,
s1
);
std
::
vector
<
size_t
>
d
(
l2
+
1
);
std
::
iota
(
d
.
begin
(),
d
.
end
(),
0
);
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
size_t
prev_cost
=
d
[
0
];
d
[
0
]
=
i
;
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
{
if
(
s1
[
i
-
1
]
==
s2
[
j
-
1
])
{
d
[
j
]
=
prev_cost
;
}
else
{
size_t
cost_insert_or_delete
=
std
::
min
(
d
[
j
-
1
],
d
[
j
]);
size_t
cost_substitute
=
prev_cost
;
prev_cost
=
d
[
j
];
d
[
j
]
=
std
::
min
(
cost_substitute
,
cost_insert_or_delete
)
+
1
;
}
}
}
return
d
[
l2
];
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/check_shapes.hpp
View file @
32b83c9c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -34,29 +34,51 @@
...
@@ -34,29 +34,51 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
struct
check_shapes
{
{
const
shape
*
begin
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
const
shape
*
end
;
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
std
::
string
name
;
bool
dynamic_allowed
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
,
MIGRAPHX_REQUIRES
(
not
std
::
is_convertible
<
Op
,
std
::
string
>{})
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
n
),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
...
@@ -81,8 +103,6 @@ struct check_shapes
...
@@ -81,8 +103,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
0
;
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
return
end
-
begin
;
}
}
...
@@ -131,11 +151,9 @@ struct check_shapes
...
@@ -131,11 +151,9 @@ struct check_shapes
*/
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
}
return
*
this
;
return
*
this
;
...
@@ -148,11 +166,9 @@ struct check_shapes
...
@@ -148,11 +166,9 @@ struct check_shapes
*/
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -166,11 +182,9 @@ struct check_shapes
...
@@ -166,11 +182,9 @@ struct check_shapes
*/
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -220,6 +234,16 @@ struct check_shapes
...
@@ -220,6 +234,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same layout.
*/
const
check_shapes
&
same_layout
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
find_permutation
(
s
);
}))
MIGRAPHX_THROW
(
prefix
()
+
"Layouts do not match"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard.
* Check all shapes are standard.
*/
*/
...
@@ -230,6 +254,16 @@ struct check_shapes
...
@@ -230,6 +254,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard or scalar.
* Check all shapes are standard or scalar.
*/
*/
...
@@ -330,8 +364,6 @@ struct check_shapes
...
@@ -330,8 +364,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
}
...
@@ -341,8 +373,6 @@ struct check_shapes
...
@@ -341,8 +373,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
}
...
@@ -351,17 +381,13 @@ struct check_shapes
...
@@ -351,17 +381,13 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
false
;
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
{
if
(
i
>=
size
())
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
if
(
i
<
0
)
return
end
-
i
;
return
end
-
i
;
return
begin
+
i
;
return
begin
+
i
;
...
@@ -394,6 +420,11 @@ struct check_shapes
...
@@ -394,6 +420,11 @@ struct check_shapes
}
}
};
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
32b83c9c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct
MIGRAPHX_EXPORT
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/convolution.hpp
View file @
32b83c9c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
...
@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
shape
win_shape
{
output_shape
.
type
(),
win_size
};
shape
win_shape
{
output_shape
.
type
(),
win_size
};
double
acc
=
0.0
;
double
acc
=
0.0
;
shape_for_each
(
win_shape
,
[
&
](
auto
idx_win
)
{
shape_for_each
(
win_shape
,
[
&
](
const
auto
&
idx_win
)
{
auto
k
=
idx_win
[
0
];
auto
k
=
idx_win
[
0
];
const
auto
in_ch
=
group_id
*
wei_c
+
k
;
const
auto
in_ch
=
group_id
*
wei_c
+
k
;
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
...
...
src/include/migraphx/matcher.hpp
View file @
32b83c9c
...
@@ -381,22 +381,24 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
...
@@ -381,22 +381,24 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
const
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
const
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
const
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
const
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
const
auto
trace_filter
=
string_value_of
(
MIGRAPHX_TRACE_MATCHES_FOR
{});
const
auto
trace_filter
=
string_value_of
(
MIGRAPHX_TRACE_MATCHES_FOR
{});
const
bool
trace_for
=
not
trace_filter
.
empty
()
and
(
contains
(
std
::
string
{
location
.
file_name
()},
trace_filter
)
or
contains
(
std
::
string
{
location
.
function_name
()},
trace_filter
));
bool
match
=
false
;
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
const
auto
&
matcher_name
=
get_type_name
(
m
);
const
bool
trace_for
=
not
trace_filter
.
empty
()
and
(
contains
(
std
::
string
{
location
.
file_name
()},
trace_filter
)
or
contains
(
std
::
string
{
location
.
function_name
()},
trace_filter
)
or
contains
(
matcher_name
,
trace_filter
));
if
(
match
)
if
(
match
)
return
;
return
;
if
(
trace
>
1
or
trace_for
)
if
(
trace
>
1
and
trace_for
)
std
::
cout
<<
"Match: "
<<
get_type
_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Match: "
<<
matcher
_name
<<
std
::
endl
;
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
if
(
r
.
result
==
get_module
(
mod
).
end
())
if
(
r
.
result
==
get_module
(
mod
).
end
())
return
;
return
;
if
(
trace
>
0
or
trace_for
)
if
(
trace
>
0
or
trace_for
)
{
{
std
::
cout
<<
"Matched by "
<<
get_type
_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
matcher
_name
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
get_module
(
mod
).
debug_print
(
ins
);
}
}
// If its already invalid dont validate it again
// If its already invalid dont validate it again
...
@@ -407,7 +409,7 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
...
@@ -407,7 +409,7 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
auto
invalid
=
get_module
(
mod
).
validate
();
auto
invalid
=
get_module
(
mod
).
validate
();
if
(
invalid
!=
get_module
(
mod
).
end
())
if
(
invalid
!=
get_module
(
mod
).
end
())
{
{
std
::
cout
<<
"Invalid program from match: "
<<
get_type
_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Invalid program from match: "
<<
matcher
_name
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
);
get_module
(
mod
).
debug_print
(
invalid
);
...
@@ -621,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
...
@@ -621,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
auto
skip
(
Ms
...
ms
)
{
{
static_assert
(((
not
std
::
is_convertible
<
Ms
,
std
::
string
>
{})
and
...),
"Use a matcher not a string for skip."
);
auto
m
=
any_of
(
ms
...);
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
optional
<
instruction_ref
>>
(
return
fix
<
optional
<
instruction_ref
>>
(
...
...
src/include/migraphx/normalize_attributes.hpp
View file @
32b83c9c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <cstring>
#include <cstring>
#include <vector>
#include <vector>
#include <migraphx/op/normalize_attribute.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -42,6 +43,36 @@ struct select_dependent_type
...
@@ -42,6 +43,36 @@ struct select_dependent_type
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
using
dependent_type
=
typename
select_dependent_type
<
T
,
Ts
...
>::
type
;
using
dependent_type
=
typename
select_dependent_type
<
T
,
Ts
...
>::
type
;
/**
* Used to normalize variable input axes at model runtime.
* Example: the axes inputs of the slice operator.
*
* \param axes the axes to normalize
* \param input_shape shape of the input tensor
* \param attr_val the normalize_axes attributes from the operator
* \param prefix error message prefix
*/
std
::
vector
<
int64_t
>
normalize_axes
(
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
=
""
);
/**
* Used to normalize variable input axes at model runtime.
* Example: the starts and ends inputs of the slice operator.
*
* \param indices the indices to normalize
* \param axes which axes the indices apply over
* \param input_shape shape of the input tensor
* \param attr_val the normalize_axes attributes from the operator
* \param prefix error message prefix
*/
std
::
vector
<
int64_t
>
normalize_indices
(
const
std
::
vector
<
int64_t
>&
indices
,
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
=
""
);
MIGRAPHX_EXPORT
MIGRAPHX_EXPORT
bool
normalize_attributes
(
operation
&
op
,
const
shape
&
input_shape
);
bool
normalize_attributes
(
operation
&
op
,
const
shape
&
input_shape
);
...
...
src/include/migraphx/op/allocate.hpp
View file @
32b83c9c
...
@@ -36,21 +36,49 @@ namespace op {
...
@@ -36,21 +36,49 @@ namespace op {
struct
allocate
struct
allocate
{
{
shape
s
{};
shape
s
{};
// for dynamic allocate to set the buffer type
shape
::
type_t
buf_type
=
shape
::
half_type
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
s
,
"shape"
));
return
pack
(
f
(
self
.
s
,
"shape"
)
,
f
(
self
.
buf_type
,
"buf_type"
)
);
}
}
std
::
string
name
()
const
{
return
"allocate"
;
}
std
::
string
name
()
const
{
return
"allocate"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
migraphx
::
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
0
);
migraphx
::
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
0
,
1
);
// check if shape attribute is not default
if
(
s
!=
shape
())
{
return
s
;
return
s
;
}
}
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
)
const
else
{
const
auto
&
out_dims
=
inputs
.
at
(
0
);
assert
(
not
out_dims
.
dynamic
());
assert
(
out_dims
.
ndim
()
==
1
);
std
::
size_t
max_val
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
out_dims
.
lens
().
at
(
0
),
shape
::
dynamic_dimension
{
0
,
max_val
});
return
{
buf_type
,
dyn_dims
};
}
}
argument
compute
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
if
(
args
.
empty
())
{
{
return
{
output_shape
};
return
{
output_shape
};
}
}
else
{
std
::
vector
<
std
::
size_t
>
output_dims
(
output_shape
.
ndim
());
args
.
at
(
0
).
visit
([
&
](
auto
a
)
{
output_dims
.
assign
(
a
.
begin
(),
a
.
end
());
});
return
{
shape
{
buf_type
,
output_dims
}};
}
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/common.hpp
View file @
32b83c9c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -33,8 +33,12 @@ namespace migraphx {
...
@@ -33,8 +33,12 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
// Specifies where to add the "extra" cell of padding if the
// calculated padding is an odd number.
// Padding mode is default_ for fixed shape padding.
// Padding mode is default_ for fixed shape padding.
// same_lower and same_upper used for dynamic padding.
// same_lower and same_upper specify dynamic padding.
// The odd cell goes at the beginning of the dimension
// (same_lower) or end (same_upper).
enum
padding_mode_t
enum
padding_mode_t
{
{
default_
,
// NOLINT
default_
,
// NOLINT
...
...
src/include/migraphx/op/contiguous.hpp
View file @
32b83c9c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
...
src/include/migraphx/op/convolution.hpp
View file @
32b83c9c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -82,7 +82,7 @@ struct convolution
...
@@ -82,7 +82,7 @@ struct convolution
const
auto
input_ndim
=
inputs
[
0
].
ndim
();
const
auto
input_ndim
=
inputs
[
0
].
ndim
();
const
auto
padding_size
=
padding
.
size
();
const
auto
padding_size
=
padding
.
size
();
if
(
input_ndim
!=
padding_size
/
2
+
2
&&
input_ndim
!=
padding_size
+
2
)
if
(
input_ndim
!=
padding_size
/
2
+
2
and
input_ndim
!=
padding_size
+
2
)
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
}
}
...
@@ -206,6 +206,7 @@ struct convolution
...
@@ -206,6 +206,7 @@ struct convolution
std
::
vector
<
std
::
size_t
>
new_padding
;
std
::
vector
<
std
::
size_t
>
new_padding
;
if
(
padding_mode
!=
op
::
padding_mode_t
::
default_
)
if
(
padding_mode
!=
op
::
padding_mode_t
::
default_
)
{
{
// auto-Calculate the padding sizes with calc_dyn_auto_pad
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
auto
weights_lens
=
args
[
1
].
get_shape
().
lens
();
auto
weights_lens
=
args
[
1
].
get_shape
().
lens
();
new_padding
=
new_padding
=
...
@@ -217,6 +218,7 @@ struct convolution
...
@@ -217,6 +218,7 @@ struct convolution
}
}
else
else
{
{
// Use the padding that was given
new_padding
=
padding
;
new_padding
=
padding
;
if
(
output_shape
.
dynamic
())
if
(
output_shape
.
dynamic
())
{
{
...
...
Prev
1
2
3
4
5
6
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment