Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
894fce68
Unverified
Commit
894fce68
authored
Jun 01, 2023
by
Paul Fultz II
Committed by
GitHub
Jun 01, 2023
Browse files
Merge branch 'develop' into ck-integration-tuning
parents
4550d41b
49b341d3
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
430 additions
and
81 deletions
+430
-81
.github/workflows/rocm-image-release.yaml
.github/workflows/rocm-image-release.yaml
+0
-6
Jenkinsfile
Jenkinsfile
+14
-3
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/driver/main.cpp
src/driver/main.cpp
+21
-9
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+5
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+42
-19
src/include/migraphx/op/run_on_target.hpp
src/include/migraphx/op/run_on_target.hpp
+98
-0
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+5
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+3
-0
src/instruction.cpp
src/instruction.cpp
+5
-1
src/module.cpp
src/module.cpp
+9
-6
src/onnx/parse_instancenorm.cpp
src/onnx/parse_instancenorm.cpp
+55
-10
src/pass_manager.cpp
src/pass_manager.cpp
+20
-14
src/program.cpp
src/program.cpp
+103
-2
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+12
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+5
-0
src/targets/fpga/subgraph.cpp
src/targets/fpga/subgraph.cpp
+1
-2
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+2
-1
test/CMakeLists.txt
test/CMakeLists.txt
+20
-0
test/argument_test.cpp
test/argument_test.cpp
+9
-0
No files found.
.github/workflows/rocm-image-release.yaml
View file @
894fce68
...
@@ -22,11 +22,6 @@ on:
...
@@ -22,11 +22,6 @@ on:
description
:
Build navi number
description
:
Build navi number
required
:
true
required
:
true
default
:
"
0"
default
:
"
0"
organization
:
type
:
string
description
:
Organization based on which location of files will be different
required
:
true
default
:
"
AMD"
overwrite
:
overwrite
:
type
:
boolean
type
:
boolean
description
:
Overwrite image if it already exists
description
:
Overwrite image if it already exists
...
@@ -38,7 +33,6 @@ jobs:
...
@@ -38,7 +33,6 @@ jobs:
with
:
with
:
rocm_release
:
${{ github.event.inputs.rocm_release || '5.1' }}
rocm_release
:
${{ github.event.inputs.rocm_release || '5.1' }}
benchmark-utils_repo
:
${{ github.event.inputs.benchmark-utils_repo || 'ROCmSoftwarePlatform/migraphx-benchmark-utils' }}
benchmark-utils_repo
:
${{ github.event.inputs.benchmark-utils_repo || 'ROCmSoftwarePlatform/migraphx-benchmark-utils' }}
organization
:
${{ github.event.inputs.organization || 'AMD' }}
base_image
:
${{ github.event.inputs.base_image || 'rocm/dev-ubuntu-20.04' }}
base_image
:
${{ github.event.inputs.base_image || 'rocm/dev-ubuntu-20.04' }}
docker_image
:
${{ github.event.inputs.docker_image || 'rocm-migraphx' }}
docker_image
:
${{ github.event.inputs.docker_image || 'rocm-migraphx' }}
build_navi
:
${{ github.event.inputs.build_navi || '0' }}
build_navi
:
${{ github.event.inputs.build_navi || '0' }}
...
...
Jenkinsfile
View file @
894fce68
...
@@ -29,7 +29,12 @@ def rocmtestnode(Map conf) {
...
@@ -29,7 +29,12 @@ def rocmtestnode(Map conf) {
mkdir build
mkdir build
cd build
cd build
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On ${flags} ..
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On ${flags} ..
make -j\$(nproc) generate all doc package check VERBOSE=1
git diff
git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1)
make -j\$(nproc) generate VERBOSE=1
git diff
git diff-index --quiet HEAD || (echo "Generated files are different. Please run make generate and commit the changes." && exit 1)
make -j\$(nproc) all doc package check VERBOSE=1
md5sum ./*.deb
md5sum ./*.deb
"""
"""
echo
cmd
echo
cmd
...
@@ -84,8 +89,10 @@ def rocmnodename(name) {
...
@@ -84,8 +89,10 @@ def rocmnodename(name) {
node_name
=
"${rocmtest_name} && vega"
;
node_name
=
"${rocmtest_name} && vega"
;
}
else
if
(
name
==
"navi21"
)
{
}
else
if
(
name
==
"navi21"
)
{
node_name
=
"${rocmtest_name} && navi21"
;
node_name
=
"${rocmtest_name} && navi21"
;
}
else
if
(
name
==
"anygpu"
)
{
node_name
=
"${rocmtest_name} && (gfx908 || gfx90a || vega)"
;
}
else
if
(
name
==
"nogpu"
)
{
}
else
if
(
name
==
"nogpu"
)
{
return
rocmtest_name
;
node_name
=
"${
rocmtest_name
} && nogpu"
;
}
}
return
node_name
return
node_name
}
}
...
@@ -115,6 +122,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
...
@@ -115,6 +122,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
stage
(
'HipRTC GPU Debug'
)
{
stage
(
'HipRTC GPU Debug'
)
{
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On"
,
gpu_debug:
true
,
hiprtc_workarounds:
true
)
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On"
,
gpu_debug:
true
,
hiprtc_workarounds:
true
)
}
}
},
all_targets_debug
:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'All targets Release'
)
{
cmake_build
(
flags:
"-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DMIGRAPHX_ENABLE_FPGA=On"
)
}
},
mlir_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
},
mlir_debug:
rocmnode
(
'vega'
)
{
cmake_build
->
stage
(
'MLIR Debug'
)
{
stage
(
'MLIR Debug'
)
{
withEnv
([
'MIGRAPHX_ENABLE_MLIR=1'
])
{
withEnv
([
'MIGRAPHX_ENABLE_MLIR=1'
])
{
...
@@ -144,7 +155,7 @@ def onnxnode(name, body) {
...
@@ -144,7 +155,7 @@ def onnxnode(name, body) {
}
}
}
}
rocmtest
onnx:
onnxnode
(
'
rocmtest
'
)
{
cmake_build
->
rocmtest
onnx:
onnxnode
(
'
anygpu
'
)
{
cmake_build
->
stage
(
"Onnx runtime"
)
{
stage
(
"Onnx runtime"
)
{
sh
'''
sh
'''
apt install half
apt install half
...
...
src/CMakeLists.txt
View file @
894fce68
...
@@ -195,6 +195,7 @@ register_migraphx_ops(
...
@@ -195,6 +195,7 @@ register_migraphx_ops(
roialign
roialign
round
round
rsqrt
rsqrt
run_on_target
scalar
scalar
scatter_add
scatter_add
scatter_mul
scatter_mul
...
...
src/driver/main.cpp
View file @
894fce68
...
@@ -415,7 +415,8 @@ struct compiler
...
@@ -415,7 +415,8 @@ struct compiler
program_params
parameters
;
program_params
parameters
;
compiler_target
ct
;
compiler_target
ct
;
compile_options
co
;
compile_options
co
;
precision
quantize
=
precision
::
fp32
;
bool
to_fp16
=
false
;
bool
to_int8
=
false
;
std
::
vector
<
std
::
string
>
fill0
;
std
::
vector
<
std
::
string
>
fill0
;
std
::
vector
<
std
::
string
>
fill1
;
std
::
vector
<
std
::
string
>
fill1
;
...
@@ -436,8 +437,8 @@ struct compiler
...
@@ -436,8 +437,8 @@ struct compiler
{
"--exhaustive-tune"
},
{
"--exhaustive-tune"
},
ap
.
help
(
"Exhastively search for best tuning parameters for kernels"
),
ap
.
help
(
"Exhastively search for best tuning parameters for kernels"
),
ap
.
set_value
(
true
));
ap
.
set_value
(
true
));
ap
(
quantize
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
precision
::
fp16
));
ap
(
to_fp16
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
true
));
ap
(
quantize
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
precision
::
int8
));
ap
(
to_int8
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
true
));
}
}
auto
params
(
const
program
&
p
)
auto
params
(
const
program
&
p
)
...
@@ -445,6 +446,11 @@ struct compiler
...
@@ -445,6 +446,11 @@ struct compiler
return
parameters
.
generate
(
p
,
ct
.
get_target
(),
co
.
offload_copy
,
l
.
batch
);
return
parameters
.
generate
(
p
,
ct
.
get_target
(),
co
.
offload_copy
,
l
.
batch
);
}
}
auto
host_params
(
const
program
&
p
)
{
return
parameters
.
generate
(
p
,
ct
.
get_target
(),
true
,
l
.
batch
);
}
program
compile
()
program
compile
()
{
{
auto
p
=
l
.
load
();
auto
p
=
l
.
load
();
...
@@ -452,13 +458,13 @@ struct compiler
...
@@ -452,13 +458,13 @@ struct compiler
if
(
p
.
is_compiled
())
if
(
p
.
is_compiled
())
return
p
;
return
p
;
auto
t
=
ct
.
get_target
();
auto
t
=
ct
.
get_target
();
if
(
quantize
==
precision
::
fp16
)
if
(
to_
fp16
)
{
{
quantize_fp16
(
p
);
quantize_fp16
(
p
);
}
}
else
if
(
quantize
==
precision
::
int8
)
if
(
to_
int8
)
{
{
quantize_int8
(
p
,
t
,
{
params
(
p
)});
quantize_int8
(
p
,
t
,
{
host_
params
(
p
)});
}
}
p
.
compile
(
t
,
co
);
p
.
compile
(
t
,
co
);
l
.
save
(
p
);
l
.
save
(
p
);
...
@@ -517,17 +523,23 @@ struct verify : command<verify>
...
@@ -517,17 +523,23 @@ struct verify : command<verify>
auto
t
=
c
.
ct
.
get_target
();
auto
t
=
c
.
ct
.
get_target
();
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
auto
quantize
=
precision
::
fp32
;
if
(
c
.
to_fp16
)
quantize
=
precision
::
fp16
;
if
(
c
.
to_int8
)
quantize
=
precision
::
int8
;
if
(
per_instruction
)
if
(
per_instruction
)
{
{
verify_instructions
(
p
,
t
,
c
.
co
,
c
.
quantize
,
tolerance
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tolerance
);
}
}
else
if
(
reduce
)
else
if
(
reduce
)
{
{
verify_reduced_program
(
p
,
t
,
c
.
co
,
c
.
quantize
,
m
,
tolerance
);
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tolerance
);
}
}
else
else
{
{
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
c
.
quantize
,
m
,
tolerance
);
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tolerance
);
}
}
}
}
};
};
...
...
src/include/migraphx/instruction.hpp
View file @
894fce68
...
@@ -136,6 +136,9 @@ struct instruction
...
@@ -136,6 +136,9 @@ struct instruction
operation
normalized_operator
()
const
;
operation
normalized_operator
()
const
;
std
::
size_t
get_target_id
()
const
;
void
set_target_id
(
std
::
size_t
tid
);
void
debug_print
()
const
;
void
debug_print
()
const
;
static
void
print
(
std
::
ostream
&
os
,
static
void
print
(
std
::
ostream
&
os
,
...
@@ -172,7 +175,8 @@ struct instruction
...
@@ -172,7 +175,8 @@ struct instruction
std
::
vector
<
instruction_ref
>
arguments
;
std
::
vector
<
instruction_ref
>
arguments
;
std
::
vector
<
module_ref
>
module_args
;
std
::
vector
<
module_ref
>
module_args
;
literal
lit
;
literal
lit
;
bool
normalized
=
false
;
bool
normalized
=
false
;
std
::
size_t
target_id
=
0
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
894fce68
...
@@ -35,6 +35,10 @@
...
@@ -35,6 +35,10 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#ifndef MIGRAPHX_USE_TYPE_ERASED_MATCHERS
#define MIGRAPHX_USE_TYPE_ERASED_MATCHERS 0
#endif
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -103,6 +107,13 @@ struct predicate_matcher
...
@@ -103,6 +107,13 @@ struct predicate_matcher
}
}
};
};
/// Convert a predicate function into a matcher
template
<
class
P
>
predicate_matcher
<
P
>
make_predicate_matcher
(
P
p
)
{
return
{
p
};
}
/// Convert a function into a matcher
/// Convert a function into a matcher
template
<
class
F
>
template
<
class
F
>
struct
function_matcher
struct
function_matcher
...
@@ -183,14 +194,26 @@ struct id_matcher
...
@@ -183,14 +194,26 @@ struct id_matcher
template
<
class
M
>
template
<
class
M
>
struct
basic_matcher
;
struct
basic_matcher
;
struct
any_matcher
;
template
<
class
M
>
struct
type_erased_matcher
{
#if MIGRAPHX_USE_TYPE_ERASED_MATCHERS
using
type
=
any_matcher
;
#else
using
type
=
basic_matcher
<
M
>
;
#endif
};
template
<
class
M
>
template
<
class
M
>
basic
_matcher
<
M
>
make_basic_matcher
(
M
m
);
typename
type_erased
_matcher
<
M
>
::
type
make_basic_matcher
(
M
m
);
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
auto
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
auto
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
template
<
class
M
>
...
@@ -222,38 +245,38 @@ struct basic_matcher
...
@@ -222,38 +245,38 @@ struct basic_matcher
auto
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
auto
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
};
};
/// Create a typed-erased matcher
using
any_matcher_base
=
basic_matcher
<
function_matcher
<
std
::
function
<
optional
<
instruction_ref
>
(
matcher_context
&
,
instruction_ref
)
>>>
;
struct
any_matcher
:
any_matcher_base
{
template
<
class
M
>
any_matcher
(
M
mm
)
:
any_matcher_base
({[
=
](
auto
&
ctx
,
auto
ins
)
{
return
mm
.
match
(
ctx
,
ins
);
}})
{
}
};
/// Create a basic matcher from a matcher
/// Create a basic matcher from a matcher
template
<
class
M
>
template
<
class
M
>
basic
_matcher
<
M
>
make_basic_matcher
(
M
m
)
typename
type_erased
_matcher
<
M
>
::
type
make_basic_matcher
(
M
m
)
{
{
return
{
m
};
return
{
m
};
}
}
/// Create a basic matcher from a function
/// Create a basic matcher from a function
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
)
auto
make_basic_fun_matcher
(
F
f
)
{
{
return
{{
f
}}
;
return
make_basic_matcher
(
make_function_matcher
(
f
))
;
}
}
/// Create a basic matcher from a predicate function
/// Create a basic matcher from a predicate function
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
)
auto
make_basic_pred_matcher
(
P
p
)
{
{
return
{{
p
}}
;
return
make_basic_matcher
(
make_predicate_matcher
(
p
))
;
}
}
/// Create a typed-erased matcher
using
any_matcher_base
=
basic_matcher
<
function_matcher
<
std
::
function
<
optional
<
instruction_ref
>
(
matcher_context
&
,
instruction_ref
)
>>>
;
struct
any_matcher
:
any_matcher_base
{
template
<
class
M
>
any_matcher
(
M
mm
)
:
any_matcher_base
({[
=
](
auto
&
ctx
,
auto
ins
)
{
return
mm
.
match
(
ctx
,
ins
);
}})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
struct name##_m \
...
...
src/include/migraphx/op/run_on_target.hpp
0 → 100644
View file @
894fce68
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_ON_TARGET_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_ON_TARGET_HPP
#include <unordered_map>
#include <vector>
#include <set>
#include <algorithm>
#include <migraphx/config.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
run_on_target
{
std
::
size_t
target_id
=
0
;
std
::
string
name
()
const
{
return
"run_on_target"
;
}
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
target_id
,
"target_id"
));
}
migraphx
::
shape
compute_shape
(
const
std
::
vector
<
migraphx
::
shape
>&
inputs
,
std
::
vector
<
migraphx
::
module_ref
>
mods
)
const
{
if
(
mods
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"RUN_ON_TARGET: must have exactly 1 module argument"
);
}
auto
*
mod_input
=
mods
.
front
();
if
(
inputs
.
size
()
!=
mod_input
->
get_parameter_shapes
().
size
())
{
MIGRAPHX_THROW
(
"RUN_ON_TARGET: Mismatched number of input parameters"
);
}
auto
mod_out_shapes
=
mod_input
->
get_output_shapes
();
return
mod_out_shapes
;
}
migraphx
::
argument
compute
(
const
migraphx
::
shape
&
,
const
std
::
vector
<
migraphx
::
argument
>&
args
,
const
std
::
vector
<
migraphx
::
module_ref
>&
mods
,
const
std
::
function
<
std
::
vector
<
migraphx
::
argument
>
(
migraphx
::
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
migraphx
::
argument
>&
)
>&
run
)
const
{
std
::
unordered_map
<
std
::
string
,
migraphx
::
argument
>
params
;
std
::
set
<
std
::
string
>
pnames
;
const
auto
*
smod
=
mods
.
front
();
assert
(
mods
.
size
()
==
1
);
auto
names
=
smod
->
get_parameter_names
();
pnames
.
insert
(
names
.
begin
(),
names
.
end
());
assert
(
pnames
.
size
()
==
args
.
size
());
std
::
transform
(
pnames
.
begin
(),
pnames
.
end
(),
args
.
begin
(),
std
::
inserter
(
params
,
params
.
end
()),
[](
auto
&&
name
,
auto
&&
arg
)
{
return
std
::
make_pair
(
name
,
arg
);
});
auto
*
mod
=
mods
.
front
();
auto
results
=
run
(
mod
,
params
);
return
migraphx
::
argument
{
results
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/pass_manager.hpp
View file @
894fce68
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
#include <vector>
...
@@ -46,6 +47,10 @@ struct module_pass_manager
...
@@ -46,6 +47,10 @@ struct module_pass_manager
virtual
~
module_pass_manager
()
{}
virtual
~
module_pass_manager
()
{}
};
};
void
run_passes
(
program
&
prog
,
module_ref
root_mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
...
...
src/include/migraphx/program.hpp
View file @
894fce68
...
@@ -92,6 +92,9 @@ struct program
...
@@ -92,6 +92,9 @@ struct program
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
void
compile
(
const
std
::
vector
<
target
>&
targets
,
std
::
vector
<
compile_options
>
compile_opts
=
{});
bool
is_compiled
()
const
;
bool
is_compiled
()
const
;
void
finalize
();
void
finalize
();
...
...
src/instruction.cpp
100755 → 100644
View file @
894fce68
...
@@ -406,6 +406,9 @@ void instruction::print(std::ostream& os,
...
@@ -406,6 +406,9 @@ void instruction::print(std::ostream& os,
// skip return instruction shape
// skip return instruction shape
if
(
ins
->
name
()
!=
"@return"
)
if
(
ins
->
name
()
!=
"@return"
)
os
<<
" -> "
<<
ins
->
get_shape
();
os
<<
" -> "
<<
ins
->
get_shape
();
// print tid
os
<<
", target_id="
<<
ins
->
target_id
;
}
}
static
void
debug_name
(
std
::
ostream
&
os
,
const
instruction
&
ins
)
static
void
debug_name
(
std
::
ostream
&
os
,
const
instruction
&
ins
)
...
@@ -469,7 +472,8 @@ operation instruction::normalized_operator() const
...
@@ -469,7 +472,8 @@ operation instruction::normalized_operator() const
}
}
return
o
;
return
o
;
}
}
std
::
size_t
instruction
::
get_target_id
()
const
{
return
target_id
;
}
void
instruction
::
set_target_id
(
std
::
size_t
tid
)
{
this
->
target_id
=
tid
;
}
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
vector
<
shape
>
shapes
(
args
.
size
());
...
...
src/module.cpp
View file @
894fce68
...
@@ -723,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
...
@@ -723,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
for
(
auto
ins
:
iterator_for
(
*
this
))
for
(
auto
ins
:
iterator_for
(
*
this
))
{
{
std
::
string
var_name
;
std
::
string
var_name
;
if
(
not
this
->
name
().
empty
()
and
this
->
name
()
!=
"main"
)
var_name
=
this
->
name
()
+
":"
;
if
(
ins
->
name
()
==
"@param"
)
if
(
ins
->
name
()
==
"@param"
)
{
{
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
var_name
.
append
(
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
)
;
}
}
else
else
{
{
var_name
=
this
->
name
();
var_name
.
append
(
"@"
+
std
::
to_string
(
count
));
var_name
.
append
((
this
->
name
().
empty
()
?
"@"
:
":@"
));
var_name
.
append
(
std
::
to_string
(
count
));
}
}
// count every instruction so index matches loc in the printout program
// count every instruction so index matches loc in the printout program
count
++
;
count
++
;
...
@@ -795,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
...
@@ -795,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
{
{
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
std
::
string
prefix
=
"x_"
;
if
(
not
contains
(
name
,
"@"
))
prefix
=
"p_"
;
return
to_c_id
(
prefix
+
replace_string
(
name
,
":"
,
"_module_"
));
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
@@ -875,7 +878,7 @@ module::print_py(std::ostream& os,
...
@@ -875,7 +878,7 @@ module::print_py(std::ostream& os,
use_abs
=
false
;
use_abs
=
false
;
if
(
use_abs
)
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_
literal
("
;
os
<<
"migraphx.generate_
argument
("
;
print_py_shape
(
os
,
ins
->
get_shape
());
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
if
(
use_abs
)
...
...
src/onnx/parse_instancenorm.cpp
View file @
894fce68
...
@@ -21,10 +21,14 @@
...
@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <iterator>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT
);
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -39,22 +43,43 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
...
@@ -39,22 +43,43 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref
parse
(
const
op_desc
&
opd
,
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
parser
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
o
args
)
const
{
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool
convert_fp16
=
true
;
if
(
enabled
(
MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT
{}))
{
convert_fp16
=
false
;
}
float
epsilon
=
1e-5
f
;
float
epsilon
=
1e-5
f
;
if
(
contains
(
info
.
attributes
,
"epsilon"
))
if
(
contains
(
info
.
attributes
,
"epsilon"
))
{
{
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
}
}
auto
dtype
=
oargs
[
0
]
->
get_shape
().
type
();
auto
literal_dtype
=
dtype
;
std
::
vector
<
instruction_ref
>
args
;
// cppcheck-suppress knownConditionTrueFalse
if
(
dtype
==
shape
::
half_type
and
convert_fp16
)
{
std
::
transform
(
oargs
.
begin
(),
oargs
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
const
auto
i
)
{
return
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
i
);
});
literal_dtype
=
shape
::
float_type
;
}
else
{
args
=
oargs
;
}
auto
x
=
args
[
0
];
auto
x
=
args
[
0
];
auto
scale
=
args
[
1
];
auto
scale
=
args
[
1
];
auto
bias
=
args
[
2
];
auto
bias
=
args
[
2
];
auto
dims
=
x
->
get_shape
().
lens
();
auto
dims
=
x
->
get_shape
().
lens
();
auto
dtype
=
x
->
get_shape
().
type
();
if
(
not
contains
(
valid_types
,
dtype
))
if
(
not
contains
(
valid_types
,
dtype
))
MIGRAPHX_THROW
(
opd
.
op_name
+
": invalid output type: "
+
std
::
to_string
(
dtype
)
+
MIGRAPHX_THROW
(
opd
.
op_name
+
": invalid output type: "
+
std
::
to_string
(
dtype
)
+
". Valid types are 1 (float), 10 (half), and 11 (double)."
);
". Valid types are 1 (float), 10 (half), and 11 (double)."
);
...
@@ -65,14 +90,29 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
...
@@ -65,14 +90,29 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
std
::
vector
<
int64_t
>
axes
(
kdims
);
std
::
vector
<
int64_t
>
axes
(
kdims
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
auto
mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
x
);
auto
mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
x
);
auto
mean_bcast
=
auto
mean_bcast
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
mean
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
mean
);
auto
l1
=
info
.
add_instruction
(
make_op
(
"sub"
),
x
,
mean_bcast
);
// for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
// reduce_sum to calculate variance i.e.
// var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
std
::
string
reduce_op_name
=
(
dtype
==
shape
::
half_type
and
not
convert_fp16
)
?
"reduce_sum"
:
"reduce_mean"
;
if
(
dtype
==
shape
::
half_type
and
not
convert_fp16
)
{
double
n
=
std
::
accumulate
(
dims
.
begin
()
+
2
,
dims
.
end
(),
1
,
[
&
](
const
auto
&
i
,
const
auto
&
j
)
{
return
i
*
j
;
});
n
=
1.0
/
std
::
sqrt
(
n
);
auto
n_literal
=
info
.
add_literal
(
literal
{
dtype
,
{
n
}});
mean_bcast
=
info
.
add_common_op
(
"mul"
,
{
mean_bcast
,
n_literal
});
x
=
info
.
add_common_op
(
"mul"
,
{
x
,
n_literal
});
}
auto
l0
=
info
.
add_instruction
(
make_op
(
"sqdiff"
),
x
,
mean_bcast
);
auto
l0
=
info
.
add_instruction
(
make_op
(
"sqdiff"
),
x
,
mean_bcast
);
auto
variance
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
l0
);
auto
variance
=
info
.
add_instruction
(
make_op
(
reduce_op_name
,
{{
"axes"
,
axes
}}),
l0
);
auto
l1
=
info
.
add_instruction
(
make_op
(
"sub"
),
x
,
mean_bcast
);
auto
epsilon_literal
=
info
.
add_literal
(
literal
{
shape
{
literal_dtype
},
{
epsilon
}});
auto
epsilon_literal
=
info
.
add_literal
(
literal
{
shape
{
dtype
},
{
epsilon
}});
auto
epsilon_bcast
=
auto
epsilon_bcast
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
epsilon_literal
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
epsilon_literal
);
auto
variance_bcast
=
auto
variance_bcast
=
...
@@ -82,11 +122,16 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
...
@@ -82,11 +122,16 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto
l4
=
info
.
add_instruction
(
make_op
(
"mul"
),
l1
,
l3
);
auto
l4
=
info
.
add_instruction
(
make_op
(
"mul"
),
l1
,
l3
);
auto
scale_bcast
=
auto
scale_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
scale
);
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
scale
);
;
auto
bias_bcast
=
auto
bias_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
bias
);
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
dims
}}),
bias
);
auto
l5
=
info
.
add_instruction
(
make_op
(
"mul"
),
l4
,
scale_bcast
);
auto
l5
=
info
.
add_instruction
(
make_op
(
"mul"
),
l4
,
scale_bcast
);
return
info
.
add_instruction
(
make_op
(
"add"
),
l5
,
bias_bcast
);
auto
ret
=
info
.
add_instruction
(
make_op
(
"add"
),
l5
,
bias_bcast
);
if
(
dtype
==
shape
::
half_type
and
convert_fp16
)
{
return
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
half_type
}}),
ret
);
}
return
ret
;
}
}
};
};
...
...
src/pass_manager.cpp
View file @
894fce68
...
@@ -123,27 +123,18 @@ struct module_pm : module_pass_manager
...
@@ -123,27 +123,18 @@ struct module_pm : module_pass_manager
module
&
get_module
(
module_pass_manager
&
mpm
)
{
return
mpm
.
get_module
();
}
module
&
get_module
(
module_pass_manager
&
mpm
)
{
return
mpm
.
get_module
();
}
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
void
run_passes
(
program
&
prog
,
module_ref
root_mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
{
module_pm
{
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
trace
=
tracer
{
std
::
cout
};
std
::
unordered_set
<
module_ref
>
visited
;
std
::
unordered_set
<
module_ref
>
visited
;
for
(
const
auto
&
p
:
passes
)
for
(
const
auto
&
p
:
passes
)
{
{
auto
mods
=
prog
.
get_modules
();
auto
tree
=
prog
.
get_module_tree
();
auto
tree
=
prog
.
get_module_tree
();
std
::
vector
<
module_ref
>
sub_mods
=
root_mod
->
get_sub_modules
();
sub_mods
.
insert
(
sub_mods
.
begin
(),
root_mod
);
visited
.
clear
();
visited
.
clear
();
for
(
const
auto
&
mod
:
reverse
(
mods
))
for
(
const
auto
&
mod
:
reverse
(
sub_
mods
))
{
{
if
(
mod
->
bypass
())
if
(
mod
->
bypass
())
continue
;
continue
;
...
@@ -167,5 +158,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
...
@@ -167,5 +158,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
}
}
}
}
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
{
module_pm
{
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
run_passes
(
prog
,
prog
.
get_main_module
(),
passes
,
trace
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/program.cpp
View file @
894fce68
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/compile_options.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
@@ -42,6 +43,7 @@
...
@@ -42,6 +43,7 @@
#include <sstream>
#include <sstream>
#include <algorithm>
#include <algorithm>
#include <set>
#include <set>
#include <unordered_map>
#include <utility>
#include <utility>
#include <unordered_set>
#include <unordered_set>
...
@@ -53,12 +55,24 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -53,12 +55,24 @@ inline namespace MIGRAPHX_INLINE_NS {
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
struct
mark_instruction_target
{
std
::
size_t
target_id
=
0
;
std
::
string
name
()
const
{
return
"mark_instruction_target"
;
}
void
apply
(
module
&
m
)
const
{
for
(
auto
&
ins
:
m
)
ins
.
set_target_id
(
target_id
);
}
};
struct
program_impl
struct
program_impl
{
{
// A map is used to keep references to modules of the program
// A map is used to keep references to modules of the program
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
context
ctx
;
context
ctx
;
std
::
string
target_name
;
std
::
string
target_name
;
std
::
vector
<
context
>
contexts
;
};
};
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
...
@@ -205,8 +219,94 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
...
@@ -205,8 +219,94 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
target_name
.
empty
();
}
bool
program
::
is_compiled
()
const
{
return
not
this
->
impl
->
target_name
.
empty
();
}
void
program
::
compile
(
const
std
::
vector
<
target
>&
targets
,
std
::
vector
<
compile_options
>
compile_opts
)
{
// Gather all the target roots
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
auto
mods
=
this
->
get_modules
();
for
(
auto
*
mod
:
mods
)
{
for
(
const
auto
&
ins
:
*
mod
)
{
if
(
ins
.
name
()
!=
"run_on_target"
)
continue
;
auto
v
=
ins
.
get_operator
().
to_value
();
module_ref
root
=
ins
.
module_inputs
().
front
();
std
::
size_t
root_target_id
=
v
.
at
(
"target_id"
).
to
<
std
::
size_t
>
();
assert
(
root_target_id
<
targets
.
size
());
roots
.
insert
({
root_target_id
,
root
});
}
}
auto
trace
=
tracer
{};
// TODO: Add tracer based on compile options
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
trace
=
tracer
{
std
::
cout
};
trace
(
*
this
);
trace
();
// It is assumed that all instructions outside of any root module would run on "ref" target
// Ref target may or may not be passed as one of the target for the "compile()".
// If it is not passed, Create one and add context of it into the map.
auto
target_idx
=
[
&
](
const
std
::
string
&
t_name
)
{
return
static_cast
<
std
::
size_t
>
(
std
::
find_if
(
targets
.
begin
(),
targets
.
end
(),
[
&
](
const
auto
&
t
)
{
return
t
.
name
()
==
t_name
;
})
-
targets
.
begin
());
};
std
::
size_t
ref_target_id
=
target_idx
(
"ref"
);
if
(
ref_target_id
==
targets
.
size
())
{
this
->
impl
->
contexts
.
resize
(
targets
.
size
()
+
1
);
this
->
impl
->
contexts
[
ref_target_id
]
=
migraphx
::
make_target
(
"ref"
).
get_context
();
// users could pass lessers compile_ops than targets, in that case use default compile_opts
compile_opts
.
resize
(
targets
.
size
()
+
1
,
migraphx
::
compile_options
{});
}
else
{
this
->
impl
->
contexts
.
resize
(
targets
.
size
());
compile_opts
.
resize
(
targets
.
size
(),
migraphx
::
compile_options
{});
}
// mark all the instruction as ref target first, later change target_id based on root-target
run_passes
(
*
this
,
{
mark_instruction_target
{
ref_target_id
}});
// Run passes on each root target
for
(
const
auto
i
:
range
(
targets
.
size
()))
{
const
auto
&
root_target
=
targets
.
at
(
i
);
auto
root_target_id
=
i
;
auto
root_modules_range
=
roots
.
equal_range
(
root_target_id
);
this
->
impl
->
contexts
[
root_target_id
]
=
root_target
.
get_context
();
for
(
const
auto
&
[
id
,
current_mod
]
:
range
(
root_modules_range
))
{
auto
passes
=
root_target
.
get_passes
(
this
->
impl
->
contexts
[
root_target_id
],
compile_opts
[
root_target_id
]);
passes
.
push_back
(
mark_instruction_target
{
static_cast
<
size_t
>
(
root_target_id
)});
run_passes
(
*
this
,
current_mod
,
passes
,
trace
);
auto
invalid
=
current_mod
->
validate
();
if
(
invalid
!=
current_mod
->
end
())
{
MIGRAPHX_THROW
(
"Invalid module "
+
current_mod
->
name
()
+
" from compilation at instruction "
+
std
::
to_string
(
std
::
distance
(
current_mod
->
begin
(),
invalid
)));
}
auto
dangling
=
current_mod
->
find_dangling_reference
();
if
(
dangling
!=
current_mod
->
end
())
{
auto
index
=
std
::
distance
(
current_mod
->
begin
(),
dangling
);
MIGRAPHX_THROW
(
"Dangling reference in module "
+
current_mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
}
current_mod
->
finalize
(
this
->
impl
->
contexts
[
root_target_id
]);
}
}
}
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
{
{
// todo: combine with multi-target compile method
assert
(
not
this
->
is_compiled
());
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target_name
=
t
.
name
();
this
->
impl
->
target_name
=
t
.
name
();
this
->
impl
->
ctx
=
t
.
get_context
();
this
->
impl
->
ctx
=
t
.
get_context
();
...
@@ -366,7 +466,6 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -366,7 +466,6 @@ std::vector<argument> generic_eval(const module* mod,
assert
(
results
.
find
(
i
)
!=
results
.
end
());
assert
(
results
.
find
(
i
)
!=
results
.
end
());
return
results
[
i
];
return
results
[
i
];
});
});
const
auto
&
mod_args
=
ins
->
module_inputs
();
const
auto
&
mod_args
=
ins
->
module_inputs
();
auto
module_eval
=
[
&
](
module_ref
smod
,
auto
module_eval
=
[
&
](
module_ref
smod
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
{
...
@@ -861,7 +960,9 @@ void program::print_py(std::ostream& os) const
...
@@ -861,7 +960,9 @@ void program::print_py(std::ostream& os) const
os
<<
"p = migraphx.program()
\n
"
;
os
<<
"p = migraphx.program()
\n
"
;
for
(
auto
&
mod
:
vec_modules
)
for
(
auto
&
mod
:
vec_modules
)
{
{
std
::
string
var_name
=
"m"
+
mod
->
name
();
std
::
string
var_name
=
"m"
;
if
(
mod
->
name
()
!=
"main"
)
var_name
+=
mod
->
name
();
os
<<
var_name
<<
" = "
;
os
<<
var_name
<<
" = "
;
if
(
mod
->
name
()
==
"main"
)
if
(
mod
->
name
()
==
"main"
)
os
<<
"p.get_main_module()"
;
os
<<
"p.get_main_module()"
;
...
...
src/rewrite_quantization.cpp
View file @
894fce68
...
@@ -40,15 +40,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -40,15 +40,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
x
->
get_shape
().
type
()
!=
y_scale
->
get_shape
().
type
())
if
(
x
->
get_shape
().
type
()
!=
y_scale
->
get_shape
().
type
())
{
{
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
x
);
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
x
);
}
}
auto
div
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
y_scale
);
auto
div
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
y_scale
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"round"
),
div
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"round"
),
div
);
if
(
ins
->
inputs
().
size
()
==
3
)
if
(
ins
->
inputs
().
size
()
==
3
)
{
{
auto
zero_point
=
m
.
insert_instruction
(
auto
zero_point
=
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
ins
->
inputs
()[
2
]);
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
ins
->
inputs
()[
2
]);
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
}
...
@@ -72,14 +75,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -72,14 +75,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
void
apply_dequantizelinear
(
module
&
m
,
instruction_ref
ins
)
void
apply_dequantizelinear
(
module
&
m
,
instruction_ref
ins
)
{
{
assert
(
ins
->
name
()
==
"dequantizelinear"
);
assert
(
ins
->
name
()
==
"dequantizelinear"
);
auto
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
ins
->
inputs
()[
0
]);
auto
x_scale
=
ins
->
inputs
()[
1
];
auto
x_scale
=
ins
->
inputs
()[
1
];
auto
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
x_scale
->
get_shape
().
type
()}}),
ins
->
inputs
()[
0
]);
if
(
ins
->
inputs
().
size
()
==
3
)
if
(
ins
->
inputs
().
size
()
==
3
)
{
{
auto
x_zero_point
=
m
.
insert_instruction
(
auto
x_zero_point
=
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
ins
->
inputs
()[
2
]);
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
x_scale
->
get_shape
().
type
()}}),
ins
->
inputs
()[
2
]);
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"sub"
),
x
,
x_zero_point
);
x
=
m
.
insert_instruction
(
ins
,
make_op
(
"sub"
),
x
,
x_zero_point
);
}
}
...
...
src/simplify_algebra.cpp
View file @
894fce68
...
@@ -501,6 +501,11 @@ struct find_inner_broadcast
...
@@ -501,6 +501,11 @@ struct find_inner_broadcast
auto
broadcasts
=
ins
->
inputs
();
auto
broadcasts
=
ins
->
inputs
();
if
(
broadcasts
.
empty
())
if
(
broadcasts
.
empty
())
return
;
return
;
// Skip if different data types are used
if
(
any_of
(
broadcasts
,
[
&
](
auto
i
)
{
return
i
->
get_shape
().
type
()
!=
broadcasts
.
front
()
->
get_shape
().
type
();
}))
return
;
bool
mixed_broadcasts
=
any_of
(
broadcasts
,
non_scalar_op
(
"broadcast"
))
and
bool
mixed_broadcasts
=
any_of
(
broadcasts
,
non_scalar_op
(
"broadcast"
))
and
any_of
(
broadcasts
,
non_scalar_op
(
"multibroadcast"
));
any_of
(
broadcasts
,
non_scalar_op
(
"multibroadcast"
));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
// If the broadcast is not a single dimension, then dont perform inner_broadcast
...
...
src/targets/fpga/subgraph.cpp
View file @
894fce68
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// assuming all FPGA instructions are in one contiguous range
// assuming all FPGA instructions are in one contiguous range
pm
->
insert_instructions
(
pm
->
end
(),
first
,
last
,
{});
pm
->
insert_instructions
(
pm
->
end
(),
first
,
std
::
next
(
last
),
{});
migraphx
::
instruction_ref
placeholder_ins
;
migraphx
::
instruction_ref
placeholder_ins
;
for
(
auto
it
:
iterator_for
(
mod
))
for
(
auto
it
:
iterator_for
(
mod
))
{
{
...
...
src/targets/gpu/lowering.cpp
View file @
894fce68
...
@@ -83,7 +83,8 @@ struct miopen_apply
...
@@ -83,7 +83,8 @@ struct miopen_apply
auto
&
ctx
=
get_context
();
auto
&
ctx
=
get_context
();
int8_x4_format
=
get_int8_x4_format
(
ctx
);
int8_x4_format
=
get_int8_x4_format
(
ctx
);
compute_fp32
=
get_compute_fp32_flag
();
compute_fp32
=
get_compute_fp32_flag
();
offload_copy
=
(
mod
->
name
()
==
"main"
)
?
pass
->
offload_copy
:
false
;
// TODO: Set Offload copy based on root modules' compile options
offload_copy
=
(
mod
->
name
()
==
"main"
)
?
pass
->
offload_copy
:
false
;
add_generic_op
(
"contiguous"
);
add_generic_op
(
"contiguous"
);
...
...
test/CMakeLists.txt
View file @
894fce68
...
@@ -190,6 +190,25 @@ if(MIGRAPHX_ENABLE_PYTHON)
...
@@ -190,6 +190,25 @@ if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory
(
py
)
add_subdirectory
(
py
)
endif
()
endif
()
# multitarget test
if
(
MIGRAPHX_ENABLE_GPU AND MIGRAPHX_ENABLE_CPU AND MIGRAPHX_ENABLE_FPGA
)
set
(
TEST_MULTI_TARGET_DIR
${
CMAKE_CURRENT_SOURCE_DIR
}
/multi_target
)
file
(
GLOB MULTI_TARGET_TESTS
${
CONFIGURE_DEPENDS
}
${
TEST_MULTI_TARGET_DIR
}
/*.cpp
)
foreach
(
MULTI_TARGET_TEST
${
MULTI_TARGET_TESTS
}
)
get_filename_component
(
BASE_NAME
${
MULTI_TARGET_TEST
}
NAME_WE
)
set
(
TEST_NAME test_
${
BASE_NAME
}
)
add_executable
(
${
TEST_NAME
}
${
MULTI_TARGET_TEST
}
)
rocm_clang_tidy_check
(
${
TEST_NAME
}
)
target_link_libraries
(
${
TEST_NAME
}
migraphx migraphx_onnx migraphx_tf migraphx_all_targets
)
target_include_directories
(
${
TEST_NAME
}
PUBLIC include
)
add_test
(
NAME
${
TEST_NAME
}
COMMAND $<TARGET_FILE:
${
TEST_NAME
}
> WORKING_DIRECTORY
${
TEST_MULTI_TARGET_DIR
}
)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
endforeach
()
endif
()
function
(
test_header NAME HEADER
)
function
(
test_header NAME HEADER
)
file
(
WRITE
${
CMAKE_CURRENT_BINARY_DIR
}
/header-main-include-
${
NAME
}
.cpp
file
(
WRITE
${
CMAKE_CURRENT_BINARY_DIR
}
/header-main-include-
${
NAME
}
.cpp
"#include <
${
HEADER
}
>
\n
int main() {}
\n
"
"#include <
${
HEADER
}
>
\n
int main() {}
\n
"
...
@@ -228,3 +247,4 @@ if(MIGRAPHX_ENABLE_FPGA)
...
@@ -228,3 +247,4 @@ if(MIGRAPHX_ENABLE_FPGA)
test_headers
(
migraphx/fpga
${
CMAKE_SOURCE_DIR
}
/src/targets/fpga/include/migraphx/fpga/*.hpp
)
test_headers
(
migraphx/fpga
${
CMAKE_SOURCE_DIR
}
/src/targets/fpga/include/migraphx/fpga/*.hpp
)
endif
()
endif
()
test/argument_test.cpp
View file @
894fce68
...
@@ -193,6 +193,15 @@ TEST_CASE(value_argument)
...
@@ -193,6 +193,15 @@ TEST_CASE(value_argument)
EXPECT
(
a4
==
a2
);
EXPECT
(
a4
==
a2
);
}
}
TEST_CASE
(
value_empty_argument
)
{
migraphx
::
argument
a5
;
EXPECT
(
a5
.
empty
());
auto
v3
=
migraphx
::
to_value
(
a5
);
auto
a6
=
migraphx
::
from_value
<
migraphx
::
argument
>
(
v3
);
EXPECT
(
a6
==
a5
);
}
TEST_CASE
(
value_tuple
)
TEST_CASE
(
value_tuple
)
{
{
auto
a1
=
make_tuple
(
3
,
3.0
,
make_tuple
(
3
,
4
));
auto
a1
=
make_tuple
(
3
,
3.0
,
make_tuple
(
3
,
4
));
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment