Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c00100f9
Commit
c00100f9
authored
Jun 30, 2023
by
Khalique Ahmed
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into nhwc_workaround
parents
cd4ab535
ad618aa9
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
604 additions
and
99 deletions
+604
-99
CMakeLists.txt
CMakeLists.txt
+6
-2
Dockerfile
Dockerfile
+4
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+21
-25
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+1
-11
src/include/migraphx/source_location.hpp
src/include/migraphx/source_location.hpp
+73
-0
src/onnx/parse_where.cpp
src/onnx/parse_where.cpp
+0
-9
src/program.cpp
src/program.cpp
+26
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-6
src/targets/cpu/CMakeLists.txt
src/targets/cpu/CMakeLists.txt
+0
-2
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+68
-24
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+32
-9
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-2
test/CMakeLists.txt
test/CMakeLists.txt
+0
-2
test/gpu/fuse_mlir.cpp
test/gpu/fuse_mlir.cpp
+157
-0
test/gpu/mlir.cpp
test/gpu/mlir.cpp
+81
-1
test/onnx/.onnxrt-commit
test/onnx/.onnxrt-commit
+1
-1
test/verify/test_pad_highest.cpp
test/verify/test_pad_highest.cpp
+1
-1
test/verify/test_pad_lowest.cpp
test/verify/test_pad_lowest.cpp
+1
-1
tools/docker/ubuntu_2204.dockerfile
tools/docker/ubuntu_2204.dockerfile
+127
-0
tools/install_prereqs.sh
tools/install_prereqs.sh
+3
-2
No files found.
CMakeLists.txt
View file @
c00100f9
...
@@ -48,7 +48,9 @@ endif()
...
@@ -48,7 +48,9 @@ endif()
set
(
CMAKE_BUILD_RPATH
"
${
CMAKE_BINARY_DIR
}
/lib"
)
set
(
CMAKE_BUILD_RPATH
"
${
CMAKE_BINARY_DIR
}
/lib"
)
project
(
migraphx
)
project
(
migraphx LANGUAGES C CXX
)
include
(
CTest
)
find_package
(
ROCM REQUIRED
)
find_package
(
ROCM REQUIRED
)
find_path
(
HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half
)
find_path
(
HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half
)
...
@@ -269,7 +271,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
...
@@ -269,7 +271,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set
(
CMAKE_RUNTIME_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/bin
)
set
(
CMAKE_RUNTIME_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/bin
)
add_subdirectory
(
src
)
add_subdirectory
(
src
)
add_subdirectory
(
docs
)
add_subdirectory
(
docs
)
add_subdirectory
(
test
)
if
(
BUILD_TESTING
)
add_subdirectory
(
test
)
endif
()
add_subdirectory
(
tools
)
add_subdirectory
(
tools
)
set
(
DEST_DIR
${
CMAKE_BINARY_DIR
}
)
set
(
DEST_DIR
${
CMAKE_BINARY_DIR
}
)
...
...
Dockerfile
View file @
c00100f9
...
@@ -12,6 +12,9 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
...
@@ -12,6 +12,9 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
# Add rocm repository
# Add rocm repository
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.5/ focal main > /etc/apt/sources.list.d/rocm.list'
RUN
sh
-c
'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.5/ focal main > /etc/apt/sources.list.d/rocm.list'
# From docs.amd.com for installing rocm. Needed to install properly
RUN
sh
-c
"echo 'Package: *
\n
Pin: release o=repo.radeon.com
\n
Pin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"
# Install dependencies
# Install dependencies
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
apt-utils
\
apt-utils
\
...
@@ -110,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
...
@@ -110,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
a997d5f51314b45d7a4c04f1599966dcf53f9b4d
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
8d25af3b3721c159bb41cc6388e9453b1018c126
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
src/include/migraphx/matcher.hpp
View file @
c00100f9
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/optional.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/source_location.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
...
@@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m)
...
@@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m)
}
}
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES_FOR
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VALIDATE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VALIDATE_MATCHES
)
/// Find matches for an instruction in the module for per section of matchers
/// Find matches for an instruction in the module for per section of matchers
template
<
class
Mod
,
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
size_t
trace_pass
,
Mod
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
void
find_matches_for
(
source_location
location
,
Mod
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
const
const
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
#endif
const
auto
trace_filter
=
string_value_of
(
MIGRAPHX_TRACE_MATCHES_FOR
{});
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
const
bool
trace_for
=
not
trace_filter
.
empty
()
and
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
(
contains
(
std
::
string
{
location
.
file_name
()},
trace_filter
)
or
const
contains
(
std
::
string
{
location
.
function_name
()},
trace_filter
));
#endif
bool
match
=
false
;
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
if
(
match
)
if
(
match
)
return
;
return
;
if
(
trace
>
1
or
trace_
pass
>
1
)
if
(
trace
>
1
or
trace_
for
)
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
if
(
r
.
result
==
get_module
(
mod
).
end
())
if
(
r
.
result
==
get_module
(
mod
).
end
())
return
;
return
;
if
(
trace
>
0
or
trace_
pass
>
0
)
if
(
trace
>
0
or
trace_
for
)
{
{
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
get_module
(
mod
).
debug_print
(
ins
);
...
@@ -420,23 +420,19 @@ void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms)
...
@@ -420,23 +420,19 @@ void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms)
/// Find matches in a module
/// Find matches in a module
template
<
class
Mod
,
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
Mod
&
mod
,
Ms
&&
...
ms
)
struct
find_matches
{
{
f
or
(
auto
ins
:
iterator_for
(
get_module
(
mod
)
))
f
ind_matches
(
Mod
&
mod
,
Ms
&&
...
ms
,
source_location
location
=
source_location
::
current
(
))
{
{
find_matches
(
0
,
mod
,
ins
,
ms
...);
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
{
find_matches_for
(
location
,
mod
,
ins
,
ms
...);
}
}
}
}
}
;
/// Find matches in a pass
template
<
class
Mod
,
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
size_t
trace_pass
,
Mod
&
mod
,
Ms
&&
...
ms
)
find_matches
(
Mod
&
mod
,
Ms
&&
...
ms
)
->
find_matches
<
Mod
,
Ms
...
>
;
{
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
{
find_matches
(
trace_pass
,
mod
,
ins
,
ms
...);
}
}
template
<
class
M
,
class
F
>
template
<
class
M
,
class
F
>
struct
find_generic_match
struct
find_generic_match
...
...
src/include/migraphx/op/convert.hpp
View file @
c00100f9
...
@@ -66,17 +66,7 @@ struct convert : unary<convert>
...
@@ -66,17 +66,7 @@ struct convert : unary<convert>
auto
type
=
target_type
;
auto
type
=
target_type
;
return
[
type
](
auto
x
)
{
return
[
type
](
auto
x
)
{
auto
y
=
x
;
auto
y
=
x
;
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
y
=
as
(
x
);
});
// clamping value between target_type's max and min doesn't work for NaNs,
if
(
std
::
isnan
(
x
))
{
y
=
as
.
nan
();
}
else
{
y
=
std
::
min
(
std
::
max
(
as
(
x
),
as
.
min
()),
as
.
max
());
}
});
return
y
;
return
y
;
};
};
}
}
...
...
src/include/migraphx/source_location.hpp
0 → 100644
View file @
c00100f9
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP

// Exposes migraphx::source_location as an alias for std::source_location,
// or std::experimental::source_location, whichever header is detected below.
// When neither is available a stub with the same accessors is provided; the
// stub reports line/column 0 and empty file/function names.

#include <migraphx/config.hpp>
// std::uint_least32_t is used by the fallback implementation below
#include <cstdint>

#if defined(CPPCHECK)
// Under cppcheck both feature macros are forced on
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#elif defined(__has_include)
// NOTE(review): C++20 defines __cplusplus as 202002L, so this condition only
// selects <source_location> for standards newer than C++20 - confirm that
// excluding C++20 is intentional.
#if __has_include(<source_location>) && __cplusplus >= 202003L
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#endif
#if __has_include(<experimental/source_location>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#else
// No __has_include support: always use the stub
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif

// The standard header takes precedence over the TS header
#if MIGRAPHX_HAS_SOURCE_LOCATION
#include <source_location>
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
#include <experimental/source_location>
#endif

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

#if MIGRAPHX_HAS_SOURCE_LOCATION
using source_location = std::source_location;
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
using source_location = std::experimental::source_location;
#else
/// Stand-in used when no library implementation is available. It carries no
/// call-site information: every accessor returns 0 or an empty string.
struct source_location
{
    static constexpr source_location current() noexcept { return source_location{}; }
    constexpr std::uint_least32_t line() const noexcept { return 0; }
    constexpr std::uint_least32_t column() const noexcept { return 0; }
    constexpr const char* file_name() const noexcept { return ""; }
    constexpr const char* function_name() const noexcept { return ""; }
};
#endif

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
src/onnx/parse_where.cpp
View file @
c00100f9
...
@@ -31,8 +31,6 @@ namespace migraphx {
...
@@ -31,8 +31,6 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
namespace
onnx
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
)
struct
parse_where
:
op_parser
<
parse_where
>
struct
parse_where
:
op_parser
<
parse_where
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Where"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Where"
}};
}
...
@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where>
...
@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where>
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
compute_broadcasted_lens
(
args
[
0
]
->
get_shape
().
lens
(),
args
[
1
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
lens
=
compute_broadcasted_lens
(
lens
,
args
[
2
]
->
get_shape
().
lens
());
if
(
enabled
(
MIGRAPHX_ENABLE_CK
{}))
{
// Convert condition tensor to int32 to work around CK not supporting bool type
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
int32_type
}}),
args
[
0
]);
}
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
if
(
args
[
0
]
->
get_shape
().
lens
()
!=
lens
)
{
{
args
[
0
]
=
args
[
0
]
=
...
...
src/program.cpp
View file @
c00100f9
...
@@ -359,6 +359,31 @@ std::string classify(T x)
...
@@ -359,6 +359,31 @@ std::string classify(T x)
}
}
}
}
// Writes summary statistics for the argument `a` to `os`.
//
// For a tensor argument this prints the minimum, maximum, mean and standard
// deviation (sum of squared deviations divided by the element count) on a
// single line. For an argument holding sub-arguments, the statistics of each
// sub-argument are printed in turn.
//
// A tensor without elements prints nothing: it has no extrema to dereference
// and the mean would be a division by zero.
void print_statistics(std::ostream& os, const argument& a)
{
    a.visit(
        [&](auto t) {
            // std::min_element/std::max_element return end() for an empty
            // range, which must not be dereferenced
            if(t.begin() == t.end())
                return;
            os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
            os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
            double num_elements = t.size();
            // Accumulate in double (0.0 seed) regardless of the element type
            auto mean   = std::reduce(t.begin(), t.end(), 0.0) / num_elements;
            auto stddev = std::sqrt(
                std::accumulate(t.begin(),
                                t.end(),
                                0.0,
                                [&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
                num_elements);
            os << "Mean: " << mean << ", ";
            os << "StdDev: " << stddev << "\n";
        },
        [&](const auto& xs) {
            for(const auto& x : xs)
            {
                print_statistics(os, x);
            }
        });
}
std
::
unordered_set
<
std
::
string
>
classify_argument
(
const
argument
&
a
)
std
::
unordered_set
<
std
::
string
>
classify_argument
(
const
argument
&
a
)
{
{
std
::
unordered_set
<
std
::
string
>
result
;
std
::
unordered_set
<
std
::
string
>
result
;
...
@@ -578,6 +603,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
...
@@ -578,6 +603,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
std
::
cout
<<
"Output: "
;
std
::
cout
<<
"Output: "
;
preview_argument
(
std
::
cout
,
buffer
);
preview_argument
(
std
::
cout
,
buffer
);
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
print_statistics
(
std
::
cout
,
buffer
);
}
}
else
else
{
{
...
...
src/simplify_algebra.cpp
View file @
c00100f9
...
@@ -39,8 +39,6 @@
...
@@ -39,8 +39,6 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <unordered_set>
#include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES
)
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -1487,13 +1485,10 @@ struct find_split_transpose
...
@@ -1487,13 +1485,10 @@ struct find_split_transpose
void
simplify_algebra
::
apply
(
module
&
m
)
const
void
simplify_algebra
::
apply
(
module
&
m
)
const
{
{
size_t
trace
=
value_of
(
MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES
{});
// Run simplifications multiple times
// Run simplifications multiple times
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
trace
,
match
::
find_matches
(
m
,
m
,
find_inner_broadcast
{},
find_inner_broadcast
{},
find_dot_broadcast
{},
find_dot_broadcast
{},
find_double_add_lit_broadcast
{},
find_double_add_lit_broadcast
{},
...
...
src/targets/cpu/CMakeLists.txt
View file @
c00100f9
...
@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
...
@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
endif
()
endif
()
endforeach
()
endforeach
()
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_cpu
)
rocm_install_targets
(
rocm_install_targets
(
TARGETS migraphx_cpu
TARGETS migraphx_cpu
INCLUDE
INCLUDE
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
c00100f9
...
@@ -139,7 +139,8 @@ struct find_mlir_op
...
@@ -139,7 +139,8 @@ struct find_mlir_op
auto
matcher
()
const
auto
matcher
()
const
{
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
match
::
name
(
"dot"
),
is_mlir_conv
()).
bind
(
"gemm_based_op"
));
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
),
is_mlir_conv
())
.
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
}
...
@@ -190,6 +191,68 @@ struct find_mlir_op
...
@@ -190,6 +191,68 @@ struct find_mlir_op
return
{
new_gemm_based_op
,
top_inputs
};
return
{
new_gemm_based_op
,
top_inputs
};
}
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
//
// Returns true when instruction `i` may be part of a fused MLIR module:
//  - its result type must be float, half, int8, int32 or bool;
//  - "@literal", "@param" and "@return" are accepted for any of those types;
//  - the ops in `no_bool_ops` are accepted unless the result type is bool;
//  - the ops in `fp_only_ops` are accepted only for float/half results;
//  - "convert" is accepted only when the result and every input are float/half.
// Anything else is rejected.
//
// NOTE(review): the inline whitelist this helper replaces also accepted "relu";
// confirm that leaving it out of `no_bool_ops` is intentional.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
    using type_t                                      = shape::type_t;
    const auto& name                                  = i.name();
    const auto result_type                            = i.get_shape().type();
    const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                         type_t::half_type,
                                                         type_t::int8_type,
                                                         type_t::int32_type,
                                                         type_t::bool_type};
    // Preliminary type check.
    if(not contains(allowed_types, result_type))
    {
        return false;
    }
    const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
    const std::initializer_list<std::string> no_bool_ops  = {"convolution",
                                                             "quant_convolution",
                                                             "dot",
                                                             "quant_dot",
                                                             "add",
                                                             "clip",
                                                             "sub",
                                                             "mul",
                                                             "div",
                                                             "pow",
                                                             "where",
                                                             "quantizelinear",
                                                             "dequantizelinear",
                                                             "abs",
                                                             "neg"};
    // Every entry needs its own comma: adjacent string literals are
    // concatenated by the compiler into a single (unmatched) name.
    const std::initializer_list<std::string> fp_only_ops = {"ceil",
                                                            "erf",
                                                            "exp",
                                                            "floor",
                                                            "log",
                                                            "recip",
                                                            "rsqrt",
                                                            "sigmoid",
                                                            "softmax",
                                                            "tanh"};
    bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
    if(contains(any_type_ops, name))
        return true;
    if(result_type != type_t::bool_type && contains(no_bool_ops, name))
        return true;
    if(is_float && contains(fp_only_ops, name))
        return true;
    // Only conversions between floating types are known to be unambiguously
    // supported.
    if(is_float && name == "convert")
    {
        return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
            return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
        });
    }
    return false;
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
...
@@ -197,31 +260,12 @@ struct find_mlir_op
...
@@ -197,31 +260,12 @@ struct find_mlir_op
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators
// Whitelist pointwise operators.
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[](
const
auto
&
i
)
{
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
contains
({
"@literal"
,
return
not
is_pointwise_op_supported_by_mlir
(
i
);
"@param"
,
"@return"
,
"convolution"
,
"quant_convolution"
,
"dot"
,
"add"
,
"relu"
,
"dequantizelinear"
,
"quantizelinear"
,
"mul"
},
i
.
name
());
}))
return
;
// Only fuse with fp32/fp16/int8/int32
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
,
shape
::
type_t
::
int32_type
},
i
->
get_shape
().
type
());
}))
}))
return
;
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
mm
->
set_bypass
();
...
...
src/targets/gpu/mlir.cpp
View file @
c00100f9
...
@@ -121,7 +121,10 @@ struct mlir_handle
...
@@ -121,7 +121,10 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using
mlir_context
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirContext
,
mlirContextDestroy
);
using
mlir_context
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirContext
,
mlirContextDestroy
);
using
mlir_thread_pool
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirLlvmThreadPool
,
mlirLlvmThreadPoolDestroy
);
using
mlir_dialect_registry
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirDialectRegistry
,
mlirDialectRegistryDestroy
);
using
mlir_module
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirModule
,
mlirModuleDestroy
);
using
mlir_module
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirModule
,
mlirModuleDestroy
);
using
mlir_operation
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOperation
,
mlirOperationDestroy
);
using
mlir_operation
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOperation
,
mlirOperationDestroy
);
using
mlir_op_printing_flags
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOpPrintingFlags
,
using
mlir_op_printing_flags
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOpPrintingFlags
,
...
@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch)
...
@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch)
struct
mlir_program
struct
mlir_program
{
{
mlir_program
()
mlir_program
()
:
ctx
(
mlirContextCreate
()),
:
ctx
(
mlirContextCreateWithRegistry
(
get_dialect_registry
().
get
(),
/*threadingEnable=*/
false
)),
location
(
mlirLocationUnknownGet
(
ctx
.
get
())),
location
(
mlirLocationUnknownGet
(
ctx
.
get
())),
mmodule
(
mlirModuleCreateEmpty
(
location
))
mmodule
(
mlirModuleCreateEmpty
(
location
))
{
{
MlirDialectRegistry
registry
=
mlirDialectRegistryCreate
();
mlirContextSetThreadPool
(
ctx
.
get
(),
get_thread_pool
().
get
());
mlirRegisterRocMLIRDialects
(
registry
);
mlirContextAppendDialectRegistry
(
ctx
.
get
(),
registry
);
mlirContextLoadAllAvailableDialects
(
ctx
.
get
());
mlirContextLoadAllAvailableDialects
(
ctx
.
get
());
mlirDialectRegistryDestroy
(
registry
);
}
mlirContextSetAllowUnregisteredDialects
(
ctx
.
get
(),
true
/*allow*/
);
// Returns the dialect registry shared by every caller. The registry is
// created, and the rocMLIR dialects and passes registered, on the first call
// only; later calls return the same object.
static mlir_dialect_registry& get_dialect_registry()
{
    static std::once_flag init_guard;
    static mlir_dialect_registry the_registry;
    // The MLIR registration functions (for dialects and passes) are not
    // necessarily thread-safe and need to be executed exactly once
    // (especially since they eventually call non-thread-safe LLVM
    // initializations).
    std::call_once(init_guard, [&]() {
        the_registry = mlirDialectRegistryCreate();
        mlirRegisterRocMLIRDialects(the_registry.get());
        mlirRegisterRocMLIRPasses();
    });
    return the_registry;
}
// Returns the single LLVM thread pool, created on first use and handed back
// on every subsequent call.
static mlir_thread_pool& get_thread_pool()
{
    // To save on overhead, we create one LLVM thread pool and reuse it
    // across all MLIR contexts as recommended by MLIR upstream.
    // Note that this is thread-safe as of C++11.
    static mlir_thread_pool the_pool = mlirLlvmThreadPoolCreate();
    return the_pool;
}
}
MlirType
make_type
(
shape
::
type_t
t
)
const
MlirType
make_type
(
shape
::
type_t
t
)
const
...
@@ -244,8 +269,6 @@ struct mlir_program
...
@@ -244,8 +269,6 @@ struct mlir_program
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
{
{
if
(
i
<
0
)
MIGRAPHX_THROW
(
"MLIR cant handle negative values since they are ambiguous"
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
}
}
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
...
...
src/targets/gpu/target.cpp
View file @
c00100f9
...
@@ -77,7 +77,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
...
@@ -77,7 +77,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_REDUCE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_REDUCE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_FAST_GELU
)
struct
id_pass
struct
id_pass
{
{
...
@@ -126,7 +125,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -126,7 +125,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module
{},
inline_module
{},
rewrite_pooling
{},
rewrite_pooling
{},
dead_code_elimination
{},
dead_code_elimination
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_FAST_GELU
{})
,
rewrite_gelu
{}),
enable_pass
(
options
.
fast_math
,
rewrite_gelu
{}),
optimize_module
{},
optimize_module
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
dead_code_elimination
{},
dead_code_elimination
{},
...
...
test/CMakeLists.txt
View file @
c00100f9
...
@@ -24,8 +24,6 @@
...
@@ -24,8 +24,6 @@
cmake_policy
(
SET CMP0057 NEW
)
cmake_policy
(
SET CMP0057 NEW
)
include
(
CTest
)
find_package
(
Threads REQUIRED
)
find_package
(
Threads REQUIRED
)
include
(
ProcessorCount
)
include
(
ProcessorCount
)
ProcessorCount
(
N
)
ProcessorCount
(
N
)
...
...
test/gpu/fuse_mlir.cpp
0 → 100644
View file @
c00100f9
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
// Applies the pass under test, gpu::fuse_mlir, to `p` and then runs
// dead_code_elimination on the result.
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
}
// Test helper that hand-builds the program a fusion is expected to produce.
//
// Creates a bypass submodule called `name` in `p`, declares one submodule
// parameter per entry of `inputs` (named by the matching entry of `arg_names`
// and shaped like that input), and lets `f(submodule, parameters)` fill in the
// body. `f` must return a tuple-like value: element 0 is the instruction whose
// operator is stored in the "op" attribute of the "gpu::mlir_op" instruction,
// element 1 is the instruction returned from the submodule.
//
// Returns the "gpu::mlir_op" instruction added to the main module, taking
// `inputs` as arguments and the new submodule as its module argument.
template <class F>
migraphx::instruction_ref add_mlir(migraphx::program& p,
                                   const std::string& name,
                                   std::vector<migraphx::instruction_ref> inputs,
                                   std::vector<std::string> arg_names,
                                   F f)
{
    assert(inputs.size() == arg_names.size() && "One interior parameter name given per input.");

    auto* main_mod = p.get_main_module();
    auto* sub_mod  = p.create_module(name);
    sub_mod->set_bypass();

    // Mirror every outer input as a parameter of the submodule
    std::vector<migraphx::instruction_ref> sub_params;
    size_t idx = 0;
    for(const auto& input : inputs)
    {
        sub_params.push_back(sub_mod->add_parameter(arg_names[idx], input->get_shape()));
        ++idx;
    }

    auto built    = f(sub_mod, sub_params);
    auto root_ins = std::get<0>(built);
    auto ret_ins  = std::get<1>(built);
    sub_mod->add_return({ret_ins});

    auto fused_op =
        migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root_ins->get_operator())}});
    return main_mod->add_instruction(fused_op, inputs, {sub_mod});
}
// A float "dot" feeding a pointwise "add" should be fused: after the pass,
// p1 must equal p2, which is built by hand around a single gpu::mlir_op.
TEST_CASE(dot_add)
{
    migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
    // p1: the unfused program handed to the pass
    migraphx::program p1;
    {
        auto* mm = p1.get_main_module();
        auto a   = mm->add_parameter("a", s);
        auto b   = mm->add_parameter("b", s);
        auto x   = mm->add_parameter("x", s);
        auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
        auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
        mm->add_return({add});
    }
    run_pass(p1);
    // p2: the expected result, with dot and add inside the mlir submodule
    migraphx::program p2;
    {
        auto* mm = p2.get_main_module();
        auto a   = mm->add_parameter("a", s);
        auto b   = mm->add_parameter("b", s);
        auto x   = mm->add_parameter("x", s);
        // Inputs are passed as {x, a, b}, so inside the lambda inputs[0] is
        // the add operand and inputs[1], inputs[2] are the dot operands
        auto fused =
            add_mlir(p2,
                     "mlir_main:pointwise0",
                     {x, a, b},
                     {"x1", "y0", "y1"},
                     [=](auto* pm, const auto& inputs) {
                         auto dot =
                             pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
                         auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]);
                         return std::make_tuple(dot, add);
                     });
        mm->add_return({fused});
    }
    EXPECT(p1.sort() == p2.sort());
}
// An int8 "quant_dot" feeding a pointwise "abs" should be fused: after the
// pass, p1 must equal the hand-built p2 containing a single gpu::mlir_op.
TEST_CASE(int_quant_dot_abs)
{
    migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
    migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
    // p1: the unfused program handed to the pass
    migraphx::program p1;
    {
        auto* mm = p1.get_main_module();
        auto a   = mm->add_parameter("a", s_a);
        auto b   = mm->add_parameter("b", s_b);
        auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
        auto abs = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("abs"));
        mm->add_return({abs});
    }
    run_pass(p1);
    // p2: the expected result, with quant_dot and abs inside the mlir submodule
    migraphx::program p2;
    {
        auto* mm   = p2.get_main_module();
        auto a     = mm->add_parameter("a", s_a);
        auto b     = mm->add_parameter("b", s_b);
        auto fused = add_mlir(
            p2, "mlir_main:pointwise0", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
                auto dot =
                    pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]);
                auto abs = pm->add_instruction(migraphx::make_op("abs"), dot);
                return std::make_tuple(dot, abs);
            });
        mm->add_return({fused});
    }
    EXPECT(p1.sort() == p2.sort());
}
// Negative case: an int8 "quant_dot" feeding a pointwise "tanh" must be left
// alone. The program is copied before the pass runs and compared afterwards.
TEST_CASE(int_quant_dot_tanh_fails)
{
    migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
    migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
    migraphx::program p1;
    {
        auto* mm  = p1.get_main_module();
        auto a    = mm->add_parameter("a", s_a);
        auto b    = mm->add_parameter("b", s_b);
        auto dot  = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
        auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
        mm->add_return({tanh});
    }
    // Snapshot of the program before the pass, used as the expected result
    migraphx::program p2(p1);
    // This pass should do nothing as int32_t tanh isn't supported.
    run_pass(p1);
    EXPECT(p1 == p2);
}
// Test driver entry point. The registered test cases are only executed when
// migraphx::gpu::mlir_enabled() reports true; otherwise the binary exits
// successfully without running anything.
int main(int argc, const char* argv[])
{
    if(!migraphx::gpu::mlir_enabled())
    {
        return 0;
    }
    test::run(argc, argv);
    return 0;
}
test/gpu/mlir.cpp
View file @
c00100f9
...
@@ -187,12 +187,39 @@ module {
...
@@ -187,12 +187,39 @@ module {
EXPECT
(
verify_mlir
(
m
));
EXPECT
(
verify_mlir
(
m
));
}
}
// Builds a module computing add(quant_dot(arg0, arg1), arg2) on int8 inputs with
// an int32 addend, then compares the dumped MLIR against the expected text and
// runs verify_mlir on the module.
TEST_CASE(quant_dot_add)
{
    const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__";

    migraphx::module m;

    // Parameters are registered in the order arg0, arg1, arg2.
    auto lhs = m.add_parameter("arg0", migraphx::shape{migraphx::shape::int8_type, {1, 5, 4}});
    auto rhs = m.add_parameter("arg1", migraphx::shape{migraphx::shape::int8_type, {1, 4, 3}});
    auto addend =
        m.add_parameter("arg2", migraphx::shape{migraphx::shape::int32_type, {1, 5, 3}});

    auto qdot = m.add_instruction(migraphx::make_op("quant_dot"), lhs, rhs);
    auto sum  = m.add_instruction(migraphx::make_op("add"), qdot, addend);
    m.add_return({sum});

    auto s = migraphx::gpu::dump_mlir(m);
    // Skip test if MLIR is not enabled
    if(s.empty())
    {
        return;
    }
    CHECK(encode(s) == encode(mlir_output));
    EXPECT(verify_mlir(m));
}
TEST_CASE
(
dot_add
)
TEST_CASE
(
dot_add
)
{
{
const
std
::
string
mlir_output
=
R"__migraphx__(
const
std
::
string
mlir_output
=
R"__migraphx__(
module {
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32>
%0 = migraphx.dot(%arg0, %arg1) :
(
tensor<1x5x4xf32>, tensor<1x4x3xf32>
)
-> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
...
@@ -246,4 +273,57 @@ module {
...
@@ -246,4 +273,57 @@ module {
EXPECT
(
verify_mlir
(
m
));
EXPECT
(
verify_mlir
(
m
));
}
}
// Builds a module computing convert(dot(arg0, arg1)) with target_type set to
// half_type, then compares the dumped MLIR against the expected text and runs
// verify_mlir on the module.
TEST_CASE(dot_convert)
{
    const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
}
}
)__migraphx__";

    migraphx::module m;

    // Parameters are registered in the order arg0, arg1.
    auto lhs = m.add_parameter("arg0", migraphx::shape{migraphx::shape::float_type, {1, 5, 4}});
    auto rhs = m.add_parameter("arg1", migraphx::shape{migraphx::shape::float_type, {1, 4, 3}});

    auto product = m.add_instruction(migraphx::make_op("dot"), lhs, rhs);

    // Convert the float dot result to half precision.
    auto to_half = migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}});
    auto converted = m.add_instruction(to_half, product);
    m.add_return({converted});

    auto s = migraphx::gpu::dump_mlir(m);
    // Skip test if MLIR is not enabled
    if(s.empty())
    {
        return;
    }
    CHECK(encode(s) == encode(mlir_output));
    EXPECT(verify_mlir(m));
}
// Builds a module computing where(arg2, dot(arg0, arg1), arg3) with a bool
// condition parameter, then compares the dumped MLIR against the expected text
// and runs verify_mlir on the module.
TEST_CASE(dot_where)
{
    const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";

    migraphx::module m;

    // Parameters are registered in the order arg0, arg1, arg2, arg3.
    auto lhs = m.add_parameter("arg0", migraphx::shape{migraphx::shape::float_type, {1, 5, 4}});
    auto rhs = m.add_parameter("arg1", migraphx::shape{migraphx::shape::float_type, {1, 4, 3}});
    auto cond = m.add_parameter("arg2", migraphx::shape{migraphx::shape::bool_type, {1, 5, 3}});
    auto fallback =
        m.add_parameter("arg3", migraphx::shape{migraphx::shape::float_type, {1, 5, 3}});

    auto product  = m.add_instruction(migraphx::make_op("dot"), lhs, rhs);
    auto selected = m.add_instruction(migraphx::make_op("where"), cond, product, fallback);
    m.add_return({selected});

    auto s = migraphx::gpu::dump_mlir(m);
    // Skip test if MLIR is not enabled
    if(s.empty())
    {
        return;
    }
    CHECK(encode(s) == encode(mlir_output));
    EXPECT(verify_mlir(m));
}
// Test driver entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
// Test driver entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
test/onnx/.onnxrt-commit
View file @
c00100f9
5a43828b3d73028bfd33b3856f82698d9ab02cb1
fbf08c4b4dce5da245189203d9f6cfc41f6663a2
test/verify/test_pad_highest.cpp
View file @
c00100f9
...
@@ -38,7 +38,7 @@ struct test_pad_highest : verify_program<test_pad_highest>
...
@@ -38,7 +38,7 @@ struct test_pad_highest : verify_program<test_pad_highest>
migraphx
::
shape
s0
{
migraphx
::
shape
::
half_type
,
{
2
,
2
}};
migraphx
::
shape
s0
{
migraphx
::
shape
::
half_type
,
{
2
,
2
}};
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
data0
});
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
data0
});
migraphx
::
op
::
pad
op
{};
migraphx
::
op
::
pad
op
{};
op
.
value
=
std
::
numeric_limits
<
float
>::
max
();
op
.
value
=
std
::
numeric_limits
<
migraphx
::
half
>::
max
();
op
.
pads
=
{
0
,
0
,
1
,
1
};
op
.
pads
=
{
0
,
0
,
1
,
1
};
mm
->
add_instruction
(
op
,
l0
);
mm
->
add_instruction
(
op
,
l0
);
return
p
;
return
p
;
...
...
test/verify/test_pad_lowest.cpp
View file @
c00100f9
...
@@ -38,7 +38,7 @@ struct test_pad_lowest : verify_program<test_pad_lowest>
...
@@ -38,7 +38,7 @@ struct test_pad_lowest : verify_program<test_pad_lowest>
migraphx
::
shape
s0
{
migraphx
::
shape
::
half_type
,
{
2
,
2
}};
migraphx
::
shape
s0
{
migraphx
::
shape
::
half_type
,
{
2
,
2
}};
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
data0
});
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
data0
});
migraphx
::
op
::
pad
op
{};
migraphx
::
op
::
pad
op
{};
op
.
value
=
std
::
numeric_limits
<
float
>::
lowest
();
op
.
value
=
std
::
numeric_limits
<
migraphx
::
half
>::
lowest
();
op
.
pads
=
{
0
,
0
,
1
,
1
};
op
.
pads
=
{
0
,
0
,
1
,
1
};
mm
->
add_instruction
(
op
,
l0
);
mm
->
add_instruction
(
op
,
l0
);
return
p
;
return
p
;
...
...
tools/docker/ubuntu_2204.dockerfile
0 → 100644
View file @
c00100f9
FROM
ubuntu:22.04
ARG
PREFIX=/usr/local
# Support multiarch
RUN
dpkg
--add-architecture
i386
# Install rocm key
RUN
apt-get update
&&
apt-get
install
-y
gnupg2
--no-install-recommends
curl
&&
\
curl
-fsSL
http://repo.radeon.com/rocm/rocm.gpg.key | gpg
--dearmor
-o
/etc/apt/trusted.gpg.d/rocm-keyring.gpg
# Add rocm repository
RUN
sh
-c
"echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] http://repo.radeon.com/rocm/apt/5.5 jammy main' > /etc/apt/sources.list.d/rocm.list"
# From docs.amd.com for installing rocm. Needed to install properly
RUN
sh
-c
"echo 'Package: *
\n
Pin: release o=repo.radeon.com
\n
Pin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"
# Install dependencies
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
apt-utils
\
build-essential
\
#clang-format-10
\
cmake
\
curl
\
doxygen
\
#g++-7
\
gdb
\
git
\
lcov
\
locales
\
pkg-config
\
python3
\
python3-dev
\
python3-pip
\
software-properties-common
\
wget
\
rocm-device-libs
\
hip-base
\
libnuma-dev
\
miopen-hip
\
rocblas
\
hipfft
\
rocthrust
\
rocrand
\
hipsparse
\
rccl
\
rccl-dev
\
rocm-smi-lib
\
rocm-dev
\
roctracer-dev
\
hipcub
\
hipblas
\
hipify-clang
\
half
\
libssl-dev
\
zlib1g-dev
&&
\
apt-get clean
&&
\
rm
-rf
/var/lib/apt/lists/
*
# Add this for roctracer dependencies
RUN
pip3
install
CppHeaderParser
# Workaround broken rocm packages
RUN
ln
-s
/opt/rocm-
*
/opt/rocm
RUN
echo
"/opt/rocm/lib"
>
/etc/ld.so.conf.d/rocm.conf
RUN
echo
"/opt/rocm/llvm/lib"
>
/etc/ld.so.conf.d/rocm-llvm.conf
RUN
ldconfig
RUN
locale-gen en_US.UTF-8
RUN
update-locale
LANG
=
en_US.UTF-8
ENV
LC_ALL=C.UTF-8
ENV
LANG=C.UTF-8
# Install dependencies
ADD
dev-requirements.txt /dev-requirements.txt
ADD
requirements.txt /requirements.txt
ADD
rbuild.ini /rbuild.ini
COPY
./tools/install_prereqs.sh /
RUN
/install_prereqs.sh /usr/local /
&&
rm
/install_prereqs.sh
RUN
test
-f
/usr/local/hash
||
exit
1
# Install yapf
RUN
pip3
install
yapf
==
0.28.0
# Install doc requirements
ADD
doc/requirements.txt /doc-requirements.txt
RUN
pip3
install
-r
/doc-requirements.txt
# Download real models to run onnx unit tests
ENV
ONNX_HOME=/.onnx
COPY
./tools/download_models.sh /
RUN
/download_models.sh
&&
rm
/download_models.sh
# Install pinned versions of zstd, ccache and cmake
RUN
cget
-p
$PREFIX
install
facebook/zstd@v1.4.5
-X
subdir
-DCMAKE_DIR
=
build/cmake
RUN
cget
-p
$PREFIX
install
ccache@v4.1
-DENABLE_TESTING
=
OFF
RUN
cget
-p
/opt/cmake
install
kitware/cmake@v3.24.3
COPY
./test/onnx/.onnxrt-commit /
ARG
ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG
ONNXRUNTIME_BRANCH=main
ARG
ONNXRUNTIME_COMMIT
RUN
git clone
--single-branch
--branch
${
ONNXRUNTIME_BRANCH
}
--recursive
${
ONNXRUNTIME_REPO
}
onnxruntime
&&
\
cd
onnxruntime
&&
\
if
[
-z
"
$ONNXRUNTIME_COMMIT
"
]
;
then
git checkout
$(
cat
/.onnxrt-commit
)
;
else
git checkout
${
ONNXRUNTIME_COMMIT
}
;
fi
&&
\
/bin/sh /onnxruntime/dockerfiles/scripts/install_common_deps.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@a997d5f51314b45d7a4c04f1599966dcf53f9b4d
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV
LD_LIBRARY_PATH=$PREFIX/lib
# Set up the UBSan environment to print stack traces
ENV
UBSAN_OPTIONS=print_stacktrace=1
ENV
ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1
RUN
ln
-s
/opt/rocm/llvm/bin/llvm-symbolizer /usr/bin/llvm-symbolizer
tools/install_prereqs.sh
View file @
c00100f9
...
@@ -56,8 +56,9 @@ echo "Dependencies are installed at $PREFIX"
...
@@ -56,8 +56,9 @@ echo "Dependencies are installed at $PREFIX"
# Install deps with rbuild
# Install deps with rbuild
rbuild prepare
-d
$PREFIX
-s
develop
rbuild prepare
-d
$PREFIX
-s
develop
# install onnx package for unit tests
export
CMAKE_ARGS
=
"-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
pip3
install
onnx
==
1.10.2
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
pip3
install
onnx
==
1.10.2
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
# pin version of protobuf in Python for onnx runtime unit tests
# pin version of protobuf in Python for onnx runtime unit tests
between dist versions
pip3
install
protobuf
==
3.20.0
pip3
install
protobuf
==
3.20.0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment