gaoqiong / MIGraphX

Commit f1c8e6c9, authored Jul 07, 2023 by turneram

    Merge remote-tracking branch 'origin/develop' into ck-integration-tuning

Parents: d09b7682, c1b8c975

Changes: 76 files in total; showing 20 changed files with 366 additions and 214 deletions (+366 -214)
src/module.cpp                                        +6    -3
src/onnx/onnx_parser.cpp                              +63   -12
src/onnx/parse_where.cpp                              +0    -9
src/pass_manager.cpp                                  +10   -2
src/program.cpp                                       +136  -130
src/promote_literals.cpp                              +1    -1
src/quantize_fp16.cpp                                 +11   -10
src/replace_allocate.cpp                              +5    -3
src/simplify_algebra.cpp                              +1    -6
src/target.cpp                                        +37   -0
src/targets/cpu/CMakeLists.txt                        +0    -2
src/targets/gpu/compile_ops.cpp                       +5    -0
src/targets/gpu/fuse_ck.cpp                           +14   -1
src/targets/gpu/fuse_mlir.cpp                         +68   -24
src/targets/gpu/include/migraphx/gpu/lowering.hpp     +2    -2
src/targets/gpu/jit/ck_gemm.cpp                       +3    -5
src/targets/gpu/jit/concat.cpp                        +1    -1
src/targets/gpu/jit/gather.cpp                        +1    -1
src/targets/gpu/jit/gathernd.cpp                      +1    -1
src/targets/gpu/jit/layernorm.cpp                     +1    -1
src/module.cpp

@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
     if(ins == std::prev(this->end()))
     {
+        // "rep" could be used earlier in the program, and moving it to the end
+        // may produce an invalid program; insert an identity operation in this case.
         return replace_instruction(ins, make_op("identity"), rep);
     }
@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
     return end();
 }

-void module::finalize(context& ctx)
+void module::finalize(std::vector<context>& contexts)
 {
+    assert(not contexts.empty());
     const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
     for(auto ins : iterator_for(*this))
     {
@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
             std::cout << "Finalize: ";
             this->debug_print(ins);
         }
-        ins->finalize(ctx);
+        ins->finalize(contexts[ins->get_target_id()]);
         for(const auto& smod : ins->module_inputs())
         {
-            smod->finalize(ctx);
+            smod->finalize(contexts);
         }
     }
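The key change above is that module::finalize now receives one context per compilation target and picks the right one through each instruction's target id. A standalone analogue of that dispatch, with toy stand-ins for the MIGraphX types (all names here are illustrative):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Toy stand-ins for migraphx::context and migraphx::instruction.
    struct context { const char* name; };
    struct instruction { std::size_t target_id = 0; };

    int main()
    {
        std::vector<context> contexts = {{"gpu"}, {"cpu"}}; // one context per target
        std::vector<instruction> prog = {{0}, {1}, {0}};
        for(const auto& ins : prog)
            std::cout << "finalize with context: " << contexts[ins.target_id].name << "\n";
    }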
src/onnx/onnx_parser.cpp

@@ -38,6 +38,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER)

 static shape shape_from_dyn_dims(shape::type_t shape_type,
                                  const std::vector<shape::dynamic_dimension>& dyn_dims)
@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type,
     return {shape_type, dyn_dims};
 }

-namespace onnx {
-
 static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
 {
     std::unordered_map<std::string, onnx::AttributeProto> result;
@@ -297,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
     return version;
 }

-std::vector<instruction_ref>
-onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
+void print_added_instructions(module* mod,
+                              const std::vector<instruction_ref>& args,
+                              const std::vector<instruction_ref>& result)
 {
+    // Print instructions added by the parser that are not in args
+    std::vector<instruction_ref> added_instructions;
+    fix([&](auto self, auto r) {
+        for(auto ins : r)
+        {
+            if(contains(args, ins))
+                continue;
+            if(contains(added_instructions, ins))
+                continue;
+            self(ins->inputs());
+            added_instructions.push_back(ins);
+        }
+    })(result);
+    mod->debug_print(added_instructions);
+}
+
+std::unordered_map<std::string, instruction_ref>
+parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
+{
+    std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            std::cout << "initializer: " << f.name() << std::endl;
         // backup instructions in parent mod
-        mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
+        mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            mod->debug_print(mod_insts[f.name()]);
     }
+    return mod_insts;
+}
+
+std::unordered_map<std::string, instruction_ref>
+parse_inputs(const onnx_parser& parser,
+             module* mod,
+             const onnx::GraphProto& graph,
+             std::unordered_map<std::string, instruction_ref> mod_insts)
+{
     for(auto&& input : graph.input())
     {
         const std::string& name = input.name();
@@ -317,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
         // scenario that a nested subgraph contains a parameter with the
         // name existed in its parent graph.
         // In the current implementation, MIGraphX throws an exception for that.
-        if(contains(instructions, name))
+        if(contains(parser.instructions, name))
         {
             MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
                            "\" existing in parent graph!");
@@ -325,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
         shape s;
         std::vector<std::size_t> dims;
-        if(map_input_dims.count(name) > 0)
+        if(parser.map_input_dims.count(name) > 0)
         {
-            dims = map_input_dims.at(name);
-            s    = parse_type(input.type(), dims);
+            dims = parser.map_input_dims.at(name);
+            s    = parser.parse_type(input.type(), dims);
         }
-        else if(map_dyn_input_dims.count(name) > 0)
+        else if(parser.map_dyn_input_dims.count(name) > 0)
         {
             shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
-            s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
+            s = shape_from_dyn_dims(shape_type, parser.map_dyn_input_dims.at(name));
         }
         else
         {
-            s = parse_type(input.type(), dims);
+            s = parser.parse_type(input.type(), dims);
         }
         mod_insts[name] = mod->add_parameter(name, s);
     }
+    return mod_insts;
+}
+
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
+{
+    std::unordered_map<std::string, instruction_ref> mod_insts =
+        parse_intializer(*this, mod, graph);
+    mod_insts = parse_inputs(*this, mod, graph, mod_insts);
+    std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
     for(auto&& node : graph.node())
     {
         if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
             std::cout << "operator: " << node.op_type() << std::endl;
         std::vector<instruction_ref> args;
         for(auto&& input : node.input())
         {
@@ -384,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
                        result.begin(),
                        std::inserter(instructions, instructions.end()),
                        [](auto&& x, auto&& y) { return std::make_pair(x, y); });
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+        {
+            print_added_instructions(mod, args, result);
+        }
     }

     // Find instructions corresponding to the output
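The new print_added_instructions helper walks producer chains recursively through MIGraphX's fix utility. For readers unfamiliar with the idiom, here is a minimal sketch of such a fixed-point combinator; this is an assumption about how the helper behaves, not the library's actual definition:

    #include <iostream>
    #include <utility>

    // Minimal fixed-point combinator: hands the recursive "self" back to the lambda.
    template <class F>
    auto fix(F f)
    {
        return [=](auto&&... xs) { return f(fix(f), std::forward<decltype(xs)>(xs)...); };
    }

    int main()
    {
        auto factorial = fix([](auto self, int n) -> int { return n <= 1 ? 1 : n * self(n - 1); });
        std::cout << factorial(5) << "\n"; // prints 120
    }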
src/onnx/parse_where.cpp

@@ -31,8 +31,6 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
-
 struct parse_where : op_parser<parse_where>
 {
     std::vector<op_desc> operators() const { return {{"Where"}}; }
@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where>
             compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
         lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());

-        if(enabled(MIGRAPHX_ENABLE_CK{}))
-        {
-            // Convert condition tensor to int32 to work around CK not supporting bool type
-            args[0] = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
-        }
-
         if(args[0]->get_shape().lens() != lens)
         {
             args[0] =
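The parser unifies the three input shapes with compute_broadcasted_lens before any of them is broadcast into place. A hedged sketch of NumPy-style broadcasting over dimension lists, assuming that is the semantics the helper implements:

    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <utility>
    #include <vector>

    // Align trailing dimensions; a dimension of 1 stretches to match the other side.
    std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
    {
        if(a.size() < b.size())
            std::swap(a, b);
        auto offset = a.size() - b.size();
        for(std::size_t i = 0; i < b.size(); i++)
        {
            auto& d = a[i + offset];
            if(d == b[i] or b[i] == 1)
                continue;
            if(d == 1)
                d = b[i];
            else
                throw std::runtime_error("incompatible lens");
        }
        return a;
    }

    int main()
    {
        auto lens = broadcast_lens({3, 1}, {1, 4});  // condition vs x -> {3, 4}
        lens      = broadcast_lens(lens, {2, 3, 4}); // vs y          -> {2, 3, 4}
        for(auto d : lens)
            std::cout << d << " ";
    }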
src/pass_manager.cpp

@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
 struct module_pm : module_pass_manager
 {
     module* mod      = nullptr;
+    module* root_mod = nullptr;
     tracer* t        = nullptr;
     module* common_parent = nullptr;
     program* prog         = nullptr;

-    module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
+    module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
+        : mod(pmod), root_mod(rmod), t(pt)
+    {
+    }

     template <class... Ts>
     void trace(Ts&&... xs) const
     {
@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
     virtual module* get_root_module() override
     {
+        if(root_mod != nullptr)
+            return root_mod;
         assert(prog);
         return prog->get_main_module();
     }
@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas
             continue;
         if(not visited.insert(mod).second)
             continue;
-        module_pm mpm{mod, &trace};
+        module_pm mpm{mod, root_mod, &trace};
         mpm.prog      = &prog;
         auto parents  = range(tree.equal_range(mod));
         auto nparents = distance(parents);
@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
         trace = tracer{std::cout};
     for(const auto& p : passes)
     {
-        module_pm{&mod, &trace}.run_pass(p);
+        module_pm{&mod, &mod, &trace}.run_pass(p);
     }
 }
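module_pm now carries an optional explicit root module, and get_root_module() prefers it, falling back to the program's main module. A toy analogue of that fallback (names are illustrative):

    #include <cassert>
    #include <iostream>

    struct module { const char* name; };

    module* get_root_module(module* root_mod, module* main_mod)
    {
        if(root_mod != nullptr)
            return root_mod; // explicit root wins
        assert(main_mod);
        return main_mod;     // fall back to the program's main module
    }

    int main()
    {
        module main_m{"main"}, sub{"sub"};
        std::cout << get_root_module(nullptr, &main_m)->name << "\n"; // main
        std::cout << get_root_module(&sub, &main_m)->name << "\n";    // sub
    }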
src/program.cpp

@@ -70,9 +70,8 @@ struct program_impl
 {
     // A map is used to keep references to modules of the program
     std::unordered_map<std::string, module> modules;
-    context ctx;
-    std::string target_name;
+    std::vector<context> contexts;
+    std::vector<target> targets;
 };

 program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
@@ -96,14 +95,8 @@ void program::assign(const program& p)
     {
         impl = std::make_unique<program_impl>();
     }
-    else if(not impl->modules.empty())
-    {
-        impl->modules.clear();
-    }
-
-    impl->ctx         = p.impl->ctx;
-    impl->target_name = p.impl->target_name;
-    impl->modules     = p.impl->modules;
+    *impl = *p.impl;

     // build a map from old ins to new ins
     // Build a map from old module to new module
@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
     return mm->get_output_shapes();
 }

-context& program::get_context() const { return impl->ctx; }
+context& program::get_context() const
+{
+    assert(impl->contexts.size() == 1);
+    return impl->contexts.front();
+}

 instruction_ref program::validate() const
 {
@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
     return p;
 }

-bool program::is_compiled() const { return not this->impl->target_name.empty(); }
+bool program::is_compiled() const { return not this->impl->contexts.empty(); }

 void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
 {
MIGRAPHX_THROW
(
"Dangling reference in module "
+
current_mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
}
current_mod
->
finalize
(
this
->
impl
->
contexts
[
root_target_id
]);
}
}
this
->
finalize
();
}
void
program
::
compile
(
const
target
&
t
,
compile_options
options
)
{
// todo: combine with multi-target compile method
assert
(
not
this
->
is_compiled
());
this
->
impl
->
target
_name
=
t
.
name
()
;
this
->
impl
->
c
tx
=
t
.
get_context
();
this
->
impl
->
target
s
=
{
t
}
;
this
->
impl
->
c
ontexts
=
{
t
.
get_context
()
}
;
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{}))
options
.
trace
=
tracer
{
std
::
cout
};
options
.
trace
(
*
this
);
options
.
trace
();
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
c
tx
,
options
);
auto
&&
passes
=
t
.
get_passes
(
this
->
impl
->
c
ontexts
.
front
()
,
options
);
run_passes
(
*
this
,
passes
,
options
.
trace
);
auto
mods
=
this
->
get_modules
();
// Validate and finalize
...
...
@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW
(
"Dangling reference in module "
+
mod
->
name
()
+
" from instruction "
+
std
::
to_string
(
index
));
}
mod
->
finalize
(
this
->
impl
->
c
tx
);
mod
->
finalize
(
this
->
impl
->
c
ontexts
);
}
}
void
program
::
finalize
()
{
auto
*
mm
=
this
->
get_main_module
();
mm
->
finalize
(
this
->
impl
->
c
tx
);
mm
->
finalize
(
this
->
impl
->
c
ontexts
);
}
template
<
class
T
>
...
...
@@ -359,6 +356,31 @@ std::string classify(T x)
     }
 }

+void print_statistics(std::ostream& os, const argument& a)
+{
+    a.visit(
+        [&](auto t) {
+            os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
+            os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
+            double num_elements = t.size();
+            auto mean           = std::accumulate(t.begin(), t.end(), 0.0) / num_elements;
+            auto stddev         = std::sqrt(
+                std::accumulate(t.begin(),
+                                t.end(),
+                                0.0,
+                                [&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
+                num_elements);
+            os << "Mean: " << mean << ", ";
+            os << "StdDev: " << stddev << "\n";
+        },
+        [&](const auto& xs) {
+            for(const auto& x : xs)
+            {
+                print_statistics(os, x);
+            }
+        });
+}
+
 std::unordered_set<std::string> classify_argument(const argument& a)
 {
     std::unordered_set<std::string> result;
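The mean and StdDev computed by print_statistics use the two-pass population form: the mean first, then the root of the average squared deviation. The same arithmetic on a plain vector, runnable standalone:

    #include <cmath>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<double> t = {1.0, 2.0, 3.0, 4.0};
        double n    = t.size();
        double mean = std::accumulate(t.begin(), t.end(), 0.0) / n;
        double stddev =
            std::sqrt(std::accumulate(t.begin(), t.end(), 0.0, [&](double r, double v) {
                          return r + std::pow(v - mean, 2.0);
                      }) / n);
        std::cout << "Mean: " << mean << ", StdDev: " << stddev << "\n"; // 2.5, ~1.118
    }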
@@ -404,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)
 template <class F>
 std::vector<argument> generic_eval(const module* mod,
-                                   context& ctx,
+                                   std::vector<context>& ctx,
                                    std::unordered_map<std::string, argument> params,
                                    std::unordered_map<instruction_ref, argument> results,
-                                   F make_trace)
+                                   F trace)
 {
     assert(mod->validate() == mod->end());
     results.reserve(mod->size() * 2);
     std::vector<argument> values;
     values.reserve(16);
-    auto trace = make_trace(mod);
     for(auto ins : iterator_for(*mod))
     {
         assert(results.find(ins) == results.end());
@@ -469,14 +490,19 @@ std::vector<argument> generic_eval(const module* mod,
         const auto& mod_args = ins->module_inputs();
         auto module_eval     = [&](module_ref smod,
                                    const std::unordered_map<std::string, argument>& inputs) {
-            auto ssctx = ctx;
-            return generic_eval(smod, ssctx, inputs, results, make_trace);
+            return generic_eval(smod, ctx, inputs, results, trace);
         };

-        results.emplace(ins, trace(ins, [&] {
-            return ins->normalized_operator().compute(
-                ctx, ins->get_shape(), values, mod_args, module_eval);
-        }));
+        results.emplace(ins, trace(ins, [&] {
+            auto op = ins->normalized_operator();
+            if(op.is_context_free())
+                return op.compute(ins->get_shape(), values, mod_args, module_eval);
+            if(ins->get_target_id() >= ctx.size())
+                MIGRAPHX_THROW("No context available for " + op.name());
+            return op.compute(
+                ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
+        }));
         assert(results.find(ins) != results.end());
         if(not ins->get_shape().any_of_dynamic())
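The updated generic_eval only hands a context to operators that need one: context-free ops compute directly, and the rest index the context vector by target id, throwing when the id is out of range. A toy analogue of that branch:

    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    struct context { int device; };
    struct op { bool context_free; std::size_t target_id; };

    int compute(const op& o, std::vector<context>& ctxs)
    {
        if(o.context_free)
            return 0;                    // no context needed
        if(o.target_id >= ctxs.size())
            throw std::runtime_error("No context available");
        return ctxs[o.target_id].device; // run with the matching context
    }

    int main()
    {
        std::vector<context> ctxs = {{1}};
        std::cout << compute({true, 0}, ctxs) << compute({false, 0}, ctxs) << "\n"; // 01
    }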
@@ -489,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,
 template <class F>
 std::vector<argument> generic_eval(const program& p,
-                                   context& ctx,
+                                   std::vector<context>& ctx,
                                    std::unordered_map<std::string, argument> params,
-                                   F make_trace)
+                                   F trace)
 {
     const module* mm = p.get_main_module();
-    return generic_eval(mm, ctx, params, {}, make_trace);
+    return generic_eval(mm, ctx, params, {}, trace);
 }

 std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
 {
-    auto& ctx = this->impl->ctx;
-
-#ifndef NDEBUG
-    auto with_check_context = [&](auto f) {
-        return [=, &ctx](auto&&) {
-            auto sctx          = std::make_shared<context>(ctx);
-            auto check_context = [=, &ctx](auto g) {
-                assert(is_shared(ctx, *sctx));
-                auto x = g();
-                *sctx  = ctx;
-                return x;
-            };
-            return [=](auto&&... xs) { return f(xs..., check_context); };
-        };
-    };
-#else
-    auto with_check_context = [](auto f) {
-        return [=](auto&&) {
-            return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
-        };
-    };
-#endif
+    auto& contexts = this->impl->contexts;

     auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
     std::vector<argument> ret;
     if(exec_env.async)
     {
-        ctx.wait_for(exec_env.queue);
+        assert(contexts.size() == 1);
+        contexts.front().wait_for(exec_env.queue);
     }
     if(trace_level > 0)
@@ -538,82 +545,79 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
             instruction::print(ss, x, ins_names);
             ins_out[x] = ss.str();
         });
-        ret = generic_eval(*this,
-                           ctx,
-                           std::move(params),
-                           with_check_context([&](auto& ins, auto f, auto&& check_context) {
-                               ctx.finish();
-                               std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
-                               timer t{};
-                               auto result = check_context(f);
-                               double t1   = t.record<milliseconds>();
-                               ctx.finish();
-                               double t2 = t.record<milliseconds>();
-                               std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
-                               if(trace_level > 1 and ins->name().front() != '@' and
-                                  ins->name() != "load" and not result.empty())
-                               {
-                                   migraphx::argument buffer;
-                                   try
-                                   {
-                                       target tgt = make_target(this->impl->target_name);
-                                       buffer     = tgt.copy_from(result);
-                                   }
-                                   catch(const migraphx::exception&)
-                                   {
-                                       // instruction was run on host then no need to copy buffer from target
-                                       buffer = result;
-                                   }
-                                   catch(...)
-                                   {
-                                       MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
-                                   }
-                                   if(trace_level == 2)
-                                   {
-                                       std::cout << "Output has " << to_string_range(classify_argument(buffer)) << std::endl;
-                                       std::cout << "Output: ";
-                                       preview_argument(std::cout, buffer);
-                                       std::cout << std::endl;
-                                   }
-                                   else
-                                   {
-                                       std::cout << "Output: " << buffer << std::endl;
-                                   }
-                               }
-                               return result;
-                           }));
+        ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
+            auto& ctx = contexts[ins->get_target_id()];
+            ctx.finish();
+            std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
+            timer t{};
+            auto result = f();
+            double t1   = t.record<milliseconds>();
+            ctx.finish();
+            double t2 = t.record<milliseconds>();
+            std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
+            if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
+               not result.empty())
+            {
+                migraphx::argument buffer;
+                try
+                {
+                    const target& tgt = this->impl->targets.at(ins->get_target_id());
+                    buffer            = tgt.copy_from(result);
+                }
+                catch(const migraphx::exception&)
+                {
+                    // instruction was run on host then no need to copy buffer from target
+                    buffer = result;
+                }
+                catch(...)
+                {
+                    MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
+                }
+                if(trace_level == 2)
+                {
+                    std::cout << "Output has " << to_string_range(classify_argument(buffer))
+                              << std::endl;
+                    std::cout << "Output: ";
+                    preview_argument(std::cout, buffer);
+                    std::cout << std::endl;
+                    print_statistics(std::cout, buffer);
+                }
+                else
+                {
+                    std::cout << "Output: " << buffer << std::endl;
+                }
+            }
+            return result;
+        });
     }
     else
     {
-        ret = generic_eval(*this,
-                           ctx,
-                           std::move(params),
-                           with_check_context([&](auto&, auto f, auto&& check_context) {
-                               return check_context(f);
-                           }));
+        ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
     }
     if(exec_env.async)
     {
-        ctx.finish_on(exec_env.queue);
+        assert(contexts.size() == 1);
+        contexts.front().finish_on(exec_env.queue);
    }
    return ret;
 }

-const int program_file_version = 5;
+void program::finish() const
+{
+    for(const auto& ctx : this->impl->contexts)
+        ctx.finish();
+}
+
+const int program_file_version = 6;

 value program::to_value() const
 {
     value result;
-    result["version"] = program_file_version;
-    result["target"]  = this->impl->target_name;
-    if(not this->impl->target_name.empty())
-        result["context"] = this->impl->ctx.to_value();
+    result["version"]  = program_file_version;
+    result["targets"]  = migraphx::to_value(this->impl->targets);
+    result["contexts"] = migraphx::to_value(this->impl->contexts);

     value module_vals = value::object{};
     std::unordered_map<instruction_ref, std::string> names;
@@ -742,12 +746,12 @@ void program::from_value(const value& v)
         MIGRAPHX_THROW("Warning: Program version mismatch");
     }

-    this->impl->target_name = v.at("target").to<std::string>();
-    if(not this->impl->target_name.empty())
+    migraphx::from_value(v.at("targets"), this->impl->targets);
+    for(auto i : range(this->impl->targets.size()))
     {
-        target t = make_target(this->impl->target_name);
-        this->impl->ctx = t.get_context();
-        this->impl->ctx.from_value(v.at("context"));
+        this->impl->contexts.push_back(this->impl->targets[i].get_context());
+        this->impl->contexts.back().from_value(v.at("contexts")[i]);
     }
     auto module_vals = v.at("modules");
@@ -768,7 +772,9 @@ void program::from_value(const value& v)
     auto* mm = get_main_module();
     mod_from_val(mm, module_vals, map_insts, map_mods);

-    this->finalize();
+    // Finalize a compiled model
+    if(not this->impl->contexts.empty())
+        this->finalize();
 }

 double common_average(const std::vector<double>& v)
@@ -788,19 +794,19 @@ std::string perf_group(const operation& op)
 void program::mark(const parameter_map& params, marker&& m)
 {
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
     // Run once by itself
     eval(params);
-    ctx.finish();
+    this->finish();
     // Start marking
     m.mark_start(*this);
-    generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
+    generic_eval(*this, ctx, params, [&](auto ins, auto f) {
         argument result;
         m.mark_start(ins);
         result = f();
         m.mark_stop(ins);
         return result;
-    }));
+    });
     m.mark_stop(*this);
 }
@@ -809,10 +815,10 @@ void program::perf_report(std::ostream& os,
                           parameter_map params,
                           std::size_t batch) const
 {
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
     // Run once by itself
     eval(params);
-    ctx.finish();
+    this->finish();

     // Run and time entire program
     std::vector<double> total_vec;
     total_vec.reserve(n);
@@ -820,28 +826,28 @@ void program::perf_report(std::ostream& os,
     {
         total_vec.push_back(time<milliseconds>([&] {
             eval(params);
-            ctx.finish();
+            this->finish();
         }));
     }
     std::sort(total_vec.begin(), total_vec.end());
     std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
     // Fill the map
-    generic_eval(*this, ctx, params, always([&](auto ins, auto) {
+    generic_eval(*this, ctx, params, [&](auto ins, auto) {
         ins_vec[ins].reserve(n);
         return argument{ins->get_shape(), nullptr};
-    }));
+    });
     // Run and time each instruction
     for(std::size_t i = 0; i < n; i++)
     {
-        generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
+        generic_eval(*this, ctx, params, [&](auto ins, auto f) {
             argument result;
             ins_vec[ins].push_back(time<milliseconds>([&] {
                 result = f();
-                ctx.finish();
+                this->impl->contexts[ins->get_target_id()].finish();
             }));
             return result;
-        }));
+        });
     }
     for(auto&& p : ins_vec)
         std::sort(p.second.begin(), p.second.end());
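perf_report times each instruction with time<milliseconds> and now synchronizes the instruction's own context before reading the clock. A self-contained analogue of the timing wrapper, with std::chrono standing in for MIGraphX's timer (names here are illustrative):

    #include <chrono>
    #include <iostream>
    #include <vector>

    // Time a callable in milliseconds; a real target would synchronize first.
    template <class F>
    double time_ms(F f)
    {
        auto start = std::chrono::steady_clock::now();
        f();
        return std::chrono::duration<double, std::milli>(std::chrono::steady_clock::now() - start)
            .count();
    }

    int main()
    {
        std::vector<double> times;
        for(int i = 0; i < 3; i++)
            times.push_back(time_ms([] {
                volatile long s = 0;
                for(long j = 0; j < 1000000; j++)
                    s = s + j;
            }));
        for(auto t : times)
            std::cout << t << " ms\n";
    }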
@@ -1009,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const
 void program::dry_run(std::unordered_map<std::string, argument> params) const
 {
-    auto& ctx = this->impl->ctx;
-    generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
+    auto& ctx = this->impl->contexts;
+    generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
         return argument{ins->get_shape(), nullptr};
-    }));
+    });
 }

 void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
src/promote_literals.cpp

@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
 {
     module& m              = mpm.get_module();
     module_ref root_module = mpm.get_root_module();
-    if(m.name() == "main")
+    if(m == *root_module)
         return;
     for(auto ins : iterator_for(m))
src/quantize_fp16.cpp

@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
         auto mod_inputs = ins->module_inputs();
         auto s          = ins->get_shape();
-        // Convert back to original type before quantizing the inputs
-        if(mod_inputs.empty())
-        {
-            auto r = m.insert_instruction(
-                std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
-            m.replace_instruction(ins, r);
-        }
-
         // Convert each of the inputs that are floating point to fp16
         auto inputs = ins->inputs();
         std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
                 ins, make_op("convert", {{"target_type", shape::half_type}}), input);
         });

-        // Replace inputs
-        m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Insert quantized ins
+        auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Convert back to original type after quantizing
+        if(mod_inputs.empty())
+        {
+            converted_ins = m.insert_instruction(
+                ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
+        }
+        // Replace original instruction
+        m.replace_instruction(ins, converted_ins);
     }
 }
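A small driver for the reworked pass, sketched against the public MIGraphX headers (treat the exact signatures as assumptions). After quantize_fp16 runs, the add's inputs are converted to half and the result converted back to the original type, matching the insert-then-replace order above:

    #include <iostream>
    #include <migraphx/make_op.hpp>
    #include <migraphx/program.hpp>
    #include <migraphx/quantization.hpp>
    #include <migraphx/shape.hpp>

    int main()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {4}};
        auto x   = mm->add_parameter("x", s);
        auto y   = mm->add_parameter("y", s);
        auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
        mm->add_return({sum});
        migraphx::quantize_fp16(p);  // quantizes all instructions by default
        std::cout << p << std::endl; // shows the inserted convert instructions
    }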
src/replace_allocate.cpp

@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/pass_manager.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/program.hpp>
@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
     mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
 }

-void replace_allocate::apply(module& m) const
+void replace_allocate::apply(module_pass_manager& mpm) const
 {
+    module& m             = mpm.get_module();
     auto mod_output_names = create_output_names(m);
-    bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
+    bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
     for(auto ins : iterator_for(m))
     {
         auto op = ins->get_operator();
@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
             continue;

         auto s = ins->get_shape();
-        if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
+        if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
         {
             auto out_param = m.add_parameter(mod_output_names[ins], s);
             m.replace_instruction(ins, out_param);
src/simplify_algebra.cpp

@@ -39,8 +39,6 @@
 #include <migraphx/algorithm.hpp>
 #include <unordered_set>

-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
-
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -1487,13 +1485,10 @@ struct find_split_transpose
 void simplify_algebra::apply(module& m) const
 {
-    size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
     // Run simplifications multiple times
     for(int i = 0; i < 8; i++)
     {
-        match::find_matches(trace,
-                            m,
+        match::find_matches(m,
                             find_inner_broadcast{},
                             find_dot_broadcast{},
                             find_double_add_lit_broadcast{},
src/target.cpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }

void migraphx_from_value(const value& v, target& t)
{
    t = make_target(v.at("name").to<std::string>());
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
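This new translation unit lets targets round-trip through migraphx::value by name alone; deserialization re-creates the target from the registry. A hedged usage sketch, assuming the "ref" target is linked in and registered:

    #include <iostream>
    #include <migraphx/register_target.hpp>
    #include <migraphx/target.hpp>
    #include <migraphx/value.hpp>

    int main()
    {
        migraphx::target t = migraphx::make_target("ref"); // assumption: ref target available
        migraphx::value v;
        migraphx_to_value(v, t);     // found via ADL; stores only {"name": "ref"}
        migraphx::target t2;
        migraphx_from_value(v, t2);  // looks the target up again by name
        std::cout << t2.name() << std::endl; // "ref"
    }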
src/targets/cpu/CMakeLists.txt

@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
 endif()
 endforeach()

-target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu)
-
 rocm_install_targets(
   TARGETS migraphx_cpu
   INCLUDE
src/targets/gpu/compile_ops.cpp

@@ -170,7 +170,11 @@ struct compile_plan
         if(results.empty())
             MIGRAPHX_THROW("No configs to tune");
+        if(results.size() == 1)
+        {
+            if(not results.front().has_value())
+                MIGRAPHX_THROW("No configs to tune");
+            return *results.front();
+        }
         if(not config)
             MIGRAPHX_THROW("Multiple kernels without config");
         std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
@@ -185,6 +189,7 @@ struct compile_plan
                        .first;
         });
         auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
         std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
+        pc.insert(preop.name(), config->problem, config->solutions.at(i));
         if(not results[i].has_value())
             MIGRAPHX_THROW("No valid tuned compilation.");
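Selecting the fastest tuned configuration reduces to the index of the minimum benchmark time, which the diff then records in the performance cache via pc.insert. The selection in isolation:

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <vector>

    int main()
    {
        std::vector<double> times = {1.8, 0.9, 1.2}; // per-config benchmark results, in ms
        auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
        std::cout << "Fastest solution index: " << i << "\n"; // prints 1
    }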
src/targets/gpu/fuse_ck.cpp

@@ -83,10 +83,23 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
+    auto m = a.lens()[a.lens().size() - 2];
+    auto n = b.lens().back();
+    auto k = a.lens().back();
+    // Integer gemms must be divisible by 4 in ck
+    if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
+    {
+        if(m % 4 != 0)
+            return false;
+        if(n % 4 != 0)
+            return false;
+        if(k % 4 != 0)
+            return false;
+    }
     // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
     // to avoid poor-performing GEMM kernels from CK
     // To-do: Investigate a more precise strategy
-    return a.lens().back() <= 2048;
+    return k <= 2048;
 }

 struct find_ck_gemm_pointwise
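The predicate's arithmetic, extracted into a standalone check (the function name and driver are illustrative): integer GEMMs must have M, N, and K divisible by 4, and any GEMM with K above 2048 is rejected.

    #include <cstddef>
    #include <iostream>

    bool ck_gemm_ok(std::size_t m, std::size_t n, std::size_t k, bool integer_gemm)
    {
        if(integer_gemm and (m % 4 != 0 or n % 4 != 0 or k % 4 != 0))
            return false;
        return k <= 2048; // skip large-K GEMMs that CK handles poorly
    }

    int main()
    {
        std::cout << ck_gemm_ok(256, 512, 2048, true) << "\n"; // 1: accepted
        std::cout << ck_gemm_ok(256, 512, 4096, true) << "\n"; // 0: K too large
        std::cout << ck_gemm_ok(255, 512, 1024, true) << "\n"; // 0: M not divisible by 4
    }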
src/targets/gpu/fuse_mlir.cpp

@@ -139,7 +139,8 @@ struct find_mlir_op
     auto matcher() const
     {
         auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
+            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
+                .bind("gemm_based_op"));
         return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }
@@ -190,6 +191,68 @@ struct find_mlir_op
         return {new_gemm_based_op, top_inputs};
     }

+    // Whitelist supported fusion options, including imposing type constraints
+    // for cases where MLIR only supports an operation (usually a pointwise function)
+    // on particular types.
+    bool is_pointwise_op_supported_by_mlir(const instruction& i) const
+    {
+        using type_t           = shape::type_t;
+        const auto& name       = i.name();
+        const auto result_type = i.get_shape().type();
+        const std::initializer_list<type_t> allowed_types = {type_t::float_type,
+                                                             type_t::half_type,
+                                                             type_t::int8_type,
+                                                             type_t::int32_type,
+                                                             type_t::bool_type};
+        // Preliminary type check.
+        if(not contains(allowed_types, result_type))
+        {
+            return false;
+        }
+        const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
+        const std::initializer_list<std::string> no_bool_ops  = {"convolution",
+                                                                 "quant_convolution",
+                                                                 "dot",
+                                                                 "quant_dot",
+                                                                 "add",
+                                                                 "clip",
+                                                                 "sub",
+                                                                 "mul",
+                                                                 "div",
+                                                                 "pow",
+                                                                 "where",
+                                                                 "quantizelinear",
+                                                                 "dequantizelinear",
+                                                                 "abs",
+                                                                 "neg"};
+        const std::initializer_list<std::string> fp_only_ops = {"ceil",
+                                                                "erf",
+                                                                "exp",
+                                                                "floor",
+                                                                "log",
+                                                                "recip",
+                                                                "rsqrt",
+                                                                "sigmoid",
+                                                                "softmax",
+                                                                "tanh"};
+        bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
+        if(contains(any_type_ops, name))
+            return true;
+        if(result_type != type_t::bool_type && contains(no_bool_ops, name))
+            return true;
+        if(is_float && contains(fp_only_ops, name))
+            return true;
+        // Only conversions between floating types are known to be unambiguously
+        // supported.
+        if(is_float && name == "convert")
+        {
+            return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
+                return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
+            });
+        }
+        return false;
+    }
+
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;
@@ -197,31 +260,12 @@ struct find_mlir_op
         auto x_ins = r.instructions["x"]; // input after contiguous
         auto* pm   = ins->module_inputs().front();
         auto names = pm->get_parameter_names();
-        // Whitelist pointwise operators
-        if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
-               return not contains({"@literal",
-                                    "@param",
-                                    "@return",
-                                    "convolution",
-                                    "quant_convolution",
-                                    "dot",
-                                    "add",
-                                    "relu",
-                                    "dequantizelinear",
-                                    "quantizelinear",
-                                    "mul"},
-                                   i.name());
-           }))
-            return;
-        // Only fuse with fp32/fp16/int8/int32
-        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
-               return not contains({shape::type_t::float_type,
-                                    shape::type_t::half_type,
-                                    shape::type_t::int8_type,
-                                    shape::type_t::int32_type},
-                                   i->get_shape().type());
-           }))
+        // Whitelist pointwise operators.
+        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+               return not is_pointwise_op_supported_by_mlir(i);
+           }))
             return;
         std::sort(names.begin(), names.end());
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
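A trimmed, standalone analogue of the new whitelist logic: op names are grouped by the result types MLIR is known to handle for them. The groups below are a small subset of the lists in the diff, for illustration only:

    #include <algorithm>
    #include <initializer_list>
    #include <iostream>
    #include <string>

    bool contains(std::initializer_list<std::string> l, const std::string& s)
    {
        return std::find(l.begin(), l.end(), s) != l.end();
    }

    bool pointwise_supported(const std::string& name, bool is_bool, bool is_float)
    {
        if(contains({"@literal", "@param", "@return"}, name))
            return true; // allowed for any type
        if(not is_bool and contains({"add", "mul", "where", "abs", "neg"}, name))
            return true; // allowed for everything except bool
        if(is_float and contains({"exp", "log", "tanh", "sigmoid"}, name))
            return true; // allowed for floating types only
        return false;
    }

    int main()
    {
        std::cout << pointwise_supported("add", false, true);  // 1
        std::cout << pointwise_supported("add", true, false);  // 0: bool add not whitelisted
        std::cout << pointwise_supported("exp", false, false); // 0: exp is float-only
    }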
src/targets/gpu/include/migraphx/gpu/lowering.hpp

@@ -30,7 +30,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {
@@ -45,7 +45,7 @@ struct lowering
     context* ctx;
     bool offload_copy;
     std::string name() const { return "gpu::lowering"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };

 } // namespace gpu
src/targets/gpu/jit/ck_gemm.cpp

@@ -66,7 +66,7 @@ ${preamble}
 extern "C" {
-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
         ck_gemm<${solution}, ${blocks_per_batch}>(xs...);
@@ -266,7 +266,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         s = shape{s.type(), {m1, m2}};
     }

-    std::vector<std::string> names() const { return {"gpu::ck_gemm"}; }
+    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

     static bool standard_batch(const shape& s)
     {
@@ -419,9 +419,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
     {
         auto shapes = to_shapes(ins->inputs());
         auto v      = create_settings(ins, op);
-        if(solution.is_null())
-            v["tuning_value"] = 4;
-        else
+        if(not solution.is_null())
             v["tuning_value"] = solution;
         return {compile_op(ctx, shapes, v),
                 [=](module& m, instruction_ref ins2, const operation& code_object) {
src/targets/gpu/jit/concat.cpp

@@ -47,7 +47,7 @@
 extern "C" {
-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
         concat<${axis}>(${concat_args})(${post}, y, xs...);
src/targets/gpu/jit/gather.cpp

@@ -44,7 +44,7 @@ namespace migraphx {
 extern "C" {
-__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
+MIGRAPHX_GLOBAL void gather_kernel(void* in_data, void* in_indices, void* output)
 {
     make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
         gather<${axis}>(xs...);
src/targets/gpu/jit/gathernd.cpp

@@ -44,7 +44,7 @@ namespace migraphx {
 extern "C" {
-__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
+MIGRAPHX_GLOBAL void gathernd_kernel(void* in_data, void* in_indices, void* output)
 {
     make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
         auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
src/targets/gpu/jit/layernorm.cpp

@@ -48,7 +48,7 @@ namespace migraphx {
 ${preamble}
 extern "C" {
-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
         ${layernorm}<${axis}>(${post}, ${eps}, xs...);
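All five JIT kernel templates swap the hard-coded __global__ qualifier for a MIGRAPHX_GLOBAL macro, so the declaration of kernel entry points is controlled in one place by the kernel prelude. The macro's real definition lives in the MIGraphX kernel headers; a hypothetical shape of such an indirection, for illustration only:

    // Hypothetical definition -- not MIGraphX's actual macro.
    #ifndef MIGRAPHX_GLOBAL
    #define MIGRAPHX_GLOBAL extern "C" __global__
    #endif

    MIGRAPHX_GLOBAL void example_kernel(void* output) {} // expands to a HIP kernel entry point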