gaoqiong / MIGraphX, commit 818e6e53 (unverified)
Authored Jun 28, 2022 by Paul Fultz II; committed by GitHub, Jun 28, 2022

Merge branch 'develop' into mlir-c
Parents: d716c516, cb18b0b5

Changes: 51. Showing 20 changed files with 295 additions and 146 deletions (+295, -146).
src/api/api.cpp (+52, -1)
src/api/include/migraphx/migraphx.h (+14, -0)
src/api/include/migraphx/migraphx.hpp (+24, -4)
src/api/migraphx.py (+9, -0)
src/include/migraphx/module.hpp (+1, -1)
src/include/migraphx/pass_manager.hpp (+1, -0)
src/include/migraphx/program.hpp (+2, -0)
src/include/migraphx/ranges.hpp (+9, -3)
src/module.cpp (+7, -4)
src/pass_manager.cpp (+25, -9)
src/program.cpp (+17, -0)
src/shape.cpp (+14, -8)
src/targets/gpu/CMakeLists.txt (+0, -1)
src/targets/gpu/fuse_ops.cpp (+38, -1)
src/targets/gpu/hip.cpp (+0, -1)
src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp (+0, -46)
src/targets/gpu/include/migraphx/gpu/hip.hpp (+1, -36)
src/targets/gpu/jit/pointwise.cpp (+39, -28)
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp (+33, -1)
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp (+9, -2)
src/api/api.cpp
@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
 void print_module(const module& m) { std::cout << m << std::endl; }
 
+migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
+{
+    return m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}}), {});
+}
+
 struct experimental_custom_op
 {
     std::string name;
@@ -260,7 +265,12 @@ struct custom_operation
         return op.compute_shape(std::move(inputs));
     }
 
     argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
 
+    // TODO: Compute method with module_args
+    argument compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
+    {
+        return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
+    }
 };
 
 template <class CustomOp>
@@ -577,6 +587,24 @@ struct migraphx_experimental_custom_op
     manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
         object_ptr = nullptr;
     migraphx::experimental_custom_op xobject;
 
+    migraphx_experimental_custom_op_compute compute_f = nullptr;
+    migraphx::argument
+    compute(migraphx::context ctx, migraphx::shape output, std::vector<migraphx::argument> inputs) const
+    {
+        std::remove_pointer_t<migraphx_argument_t> out;
+        if(compute_f == nullptr)
+            throw std::runtime_error("compute function is missing.");
+        auto api_error_result = compute_f(&out,
+                                          object_ptr.data,
+                                          object_cast<migraphx_context_t>(&(ctx)),
+                                          object_cast<migraphx_shape_t>(&(output)),
+                                          object_cast<migraphx_arguments_t>(&(inputs)));
+        if(api_error_result != migraphx_status_success)
+            throw std::runtime_error("Error in compute.");
+        return (&out)->object;
+    }
+
     migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
     migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
     {
@@ -1141,6 +1169,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
     return api_error_result;
 }
 
+extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
+                                                          migraphx_module_t module,
+                                                          const_migraphx_shape_t s)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(module == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
+        if(s == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
+        *out = allocate<migraphx_instruction_t>(migraphx::add_allocation((module->object), (s->object)));
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
 {
     auto api_error_result = migraphx::try_([&] { destroy((program)); });
@@ -1772,6 +1815,14 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
     return api_error_result;
 }
 
+extern "C" migraphx_status migraphx_experimental_custom_op_set_compute(
+    migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute input)
+{
+    auto api_error_result = migraphx::try_([&] { (obj)->compute_f = (input); });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
 {
src/api/include/migraphx/migraphx.h
@@ -129,6 +129,12 @@ typedef const struct migraphx_context* const_migraphx_context_t;
 typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
 typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
 
+typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
+                                                                   void* obj,
+                                                                   migraphx_context_t ctx,
+                                                                   migraphx_shape_t output,
+                                                                   migraphx_arguments_t inputs);
+
 typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
                                                                          void* obj,
                                                                          migraphx_shapes_t inputs);
@@ -295,6 +301,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
                                            migraphx_module_t module,
                                            migraphx_instructions_t args);
 
+migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
+                                               migraphx_module_t module,
+                                               const_migraphx_shape_t s);
+
 migraphx_status migraphx_program_destroy(migraphx_program_t program);
 
 migraphx_status migraphx_program_assign_to(migraphx_program_t output,
@@ -477,6 +487,10 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
                                        migraphx_experimental_custom_op_delete d,
                                        const char* name);
 
+migraphx_status migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
+                                                            migraphx_experimental_custom_op_compute input);
+
 migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
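For illustration only (not part of this commit): a minimal C++ sketch of calling the new migraphx_module_add_allocation entry point declared above. The helper name add_allocation_or_throw and the error handling are assumptions; the module and shape handles are expected to come from the usual C API setup.

    #include <migraphx/migraphx.h>
    #include <stdexcept>

    // Hypothetical helper: append an allocation instruction to a module via the C API.
    migraphx_instruction_t add_allocation_or_throw(migraphx_module_t mod, const_migraphx_shape_t out_shape)
    {
        migraphx_instruction_t alloc_ins;
        if(migraphx_module_add_allocation(&alloc_ins, mod, out_shape) != migraphx_status_success)
            throw std::runtime_error("migraphx_module_add_allocation failed");
        return alloc_ins;
    }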
src/api/include/migraphx/migraphx.hpp
@@ -401,11 +401,14 @@ struct interface_base : Base
         return x;
     }
 
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
     template <class T>
     auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
     {
         return as_handle<T>{x};
     }
+#pragma GCC diagnostic pop
 
     template <class T>
     auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
@@ -565,6 +568,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
         return pout;
     }
 
+    template <typename T>
+    std::vector<T> as_vector() const
+    {
+        size_t vector_len = this->get_shape().bytes() / sizeof(T);
+        T* buffer_ptr     = reinterpret_cast<T*>(this->data());
+        return {buffer_ptr, buffer_ptr + vector_len};
+    }
+
     /// Generate an argument using random data
     static argument generate(shape ps, size_t pseed = 0)
     {
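For illustration only (not part of this commit): a small sketch of the new argument::as_vector helper used together with argument::generate from the same class. The function name read_back_example and the 2x3 float shape are arbitrary choices.

    #include <migraphx/migraphx.hpp>
    #include <vector>

    // Hypothetical usage: generate a random argument, then read its buffer back as floats.
    std::vector<float> read_back_example()
    {
        migraphx::shape s{migraphx_shape_float_type, {2, 3}};
        migraphx::argument arg = migraphx::argument::generate(s);
        return arg.as_vector<float>(); // six values, one per element of the 2x3 shape
    }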
@@ -802,13 +813,20 @@ struct module
         return instruction(ret_ins, own{});
     }
 
+    instruction add_allocation(const migraphx::shape& s)
+    {
+        migraphx_instruction_t ret_ins;
+        call(&migraphx_module_add_allocation, &ret_ins, mm.get(), s.get_handle_ptr());
+        return instruction(ret_ins, own{});
+    }
+
     migraphx_module_t get_handle_ptr() const { return mm.get(); }
 
     private:
     std::shared_ptr<migraphx_module> mm;
 };
 
-struct context
+struct context : handle_lookup<context, migraphx_context>
 {
     context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
@@ -1177,9 +1195,10 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
 struct experimental_custom_op_base
 {
-    virtual std::string name() const                 = 0;
-    virtual shape compute_shape(shapes inputs) const = 0;
-    virtual ~experimental_custom_op_base()           = default;
+    virtual std::string name() const                                            = 0;
+    virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
+    virtual shape compute_shape(shapes inputs) const                            = 0;
+    virtual ~experimental_custom_op_base()                                      = default;
 };
 
 struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
@@ -1189,6 +1208,7 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
     {
         this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
         MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
+        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
     }
 
     void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
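For illustration only (not part of this commit): the hunks above add a compute virtual to experimental_custom_op_base and lift it into the experimental_custom_op handle. A minimal sketch of a user-defined op built on that interface follows; identity_custom_op and register_identity_op are invented names, and the wrapping/registration wiring is an assumption based on the constructor and register_op() shown above.

    #include <migraphx/migraphx.hpp>

    // Hypothetical pass-through op derived from the extended base class.
    struct identity_custom_op : migraphx::experimental_custom_op_base
    {
        std::string name() const override { return "identity_custom_op"; }

        migraphx::shape compute_shape(migraphx::shapes inputs) const override
        {
            return inputs[0]; // output shape mirrors the first input
        }

        migraphx::argument
        compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
        {
            return inputs[0]; // return the first input unchanged
        }
    };

    // Assumed wiring: the wrapped object must outlive the handle that lifts it.
    void register_identity_op()
    {
        static identity_custom_op op_obj;
        migraphx::experimental_custom_op wrapped{op_obj}; // lifts compute/compute_shape
        wrapped.register_op();
    }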
src/api/migraphx.py
@@ -244,6 +244,10 @@ def module(h):
     h.method('add_return',
              api.params(args='std::vector<migraphx::instruction_ref>'),
              returns='migraphx::instruction_ref')
+    h.method('add_allocation',
+             api.params(s='const migraphx::shape&'),
+             invoke='migraphx::add_allocation($@)',
+             returns='migraphx::instruction_ref')
 
 
 @auto_handle()
@@ -436,6 +440,11 @@ def context(h):
                'migraphx::experimental_custom_op')
 def experimental_custom_op(h):
     h.constructor('create', api.params(name='const char*'))
+    h.virtual('compute',
+              api.params(ctx='migraphx::context',
+                         output='migraphx::shape',
+                         inputs='std::vector<migraphx::argument>'),
+              returns='migraphx::argument')
     h.virtual('compute_shape',
               api.params(inputs='std::vector<migraphx::shape>'),
               returns='migraphx::shape')
src/include/migraphx/module.hpp
@@ -183,7 +183,7 @@ struct module
     void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
 
-    std::vector<module_ref> get_sub_modules() const;
+    std::vector<module_ref> get_sub_modules(bool shallow = false) const;
 
     module& sort();
     ins_dep_map calc_implicit_deps() const;
src/include/migraphx/pass_manager.hpp
@@ -38,6 +38,7 @@ struct module_pass_manager
     module_pass_manager(const module_pass_manager&) = delete;
     virtual module& get_module()                            = 0;
     virtual module* create_module(const std::string& name)  = 0;
+    virtual module* get_common_parent()                     = 0;
     virtual void run_pass(const pass& p)                    = 0;
 
     protected:
src/include/migraphx/program.hpp
@@ -132,6 +132,8 @@ struct program
     std::vector<const module*> get_modules() const;
     std::vector<module*> get_modules();
 
+    std::unordered_multimap<module_ref, module_ref> get_module_tree();
+
     void remove_module(const std::string& name);
     void remove_unused_modules();
src/include/migraphx/ranges.hpp
@@ -216,10 +216,16 @@ void replace(Range&& r, const T& old, const T& new_x)
     std::replace(r.begin(), r.end(), old, new_x);
 }
 
-template <class R1, class R2>
-bool equal(R1&& r1, R2&& r2)
+template <class R1, class R2, class... Predicate>
+bool equal(R1&& r1, R2&& r2, Predicate... pred)
 {
-    return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end());
+    return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
 }
 
+template <class Range>
+auto distance(Range&& r)
+{
+    return std::distance(r.begin(), r.end());
+}
+
 template <class R>
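For illustration only (not part of this commit): the extended equal() forwards an optional element predicate to std::equal, and the new distance() measures any range exposing begin()/end(). The helper names nearly_same and count_elements, and the 1e-6 tolerance, are arbitrary.

    #include <migraphx/ranges.hpp>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Hypothetical: compare two float ranges with a tolerance instead of exact equality.
    bool nearly_same(const std::vector<float>& a, const std::vector<float>& b)
    {
        return migraphx::equal(a, b, [](float x, float y) { return std::fabs(x - y) < 1e-6f; });
    }

    // Hypothetical: count the elements of any range with begin()/end().
    std::ptrdiff_t count_elements(const std::vector<int>& v) { return migraphx::distance(v); }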
src/module.cpp
@@ -821,17 +821,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
     });
 }
 
-std::vector<module_ref> module::get_sub_modules() const
+std::vector<module_ref> module::get_sub_modules(bool shallow) const
 {
     std::vector<module_ref> vec_modules;
     for(auto ins : iterator_for(*this))
     {
         const auto& mod_args = ins->module_inputs();
         vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
-        for(const auto& smod : mod_args)
+        if(not shallow)
         {
-            auto sub_mods = smod->get_sub_modules();
-            vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
+            for(const auto& smod : mod_args)
+            {
+                auto sub_mods = smod->get_sub_modules();
+                vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
+            }
         }
     }
src/pass_manager.cpp
@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
 struct module_pm : module_pass_manager
 {
-    module* mod;
-    program* prog;
-    tracer* t;
+    module* mod           = nullptr;
+    tracer* t             = nullptr;
+    module* common_parent = nullptr;
+    program* prog         = nullptr;
 
-    module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
-        : mod(pmod), prog(pprog), t(pt)
-    {
-    }
+    module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
 
     template <class... Ts>
     void trace(Ts&&... xs) const
@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
         assert(prog);
         return prog->create_module(name);
     }
+    virtual module* get_common_parent() override { return common_parent; }
     virtual void run_pass(const pass& p) override
     {
         assert(mod);
@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
         trace = tracer{std::cout};
     for(const auto& p : passes)
     {
-        module_pm{&mod, nullptr, &trace}.run_pass(p);
+        module_pm{&mod, &trace}.run_pass(p);
     }
 }
@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
 {
     if(enabled(MIGRAPHX_TRACE_PASSES{}))
         trace = tracer{std::cout};
+    std::unordered_set<module_ref> visited;
     for(const auto& p : passes)
     {
         auto mods = prog.get_modules();
+        auto tree = prog.get_module_tree();
+        visited.clear();
         for(const auto& mod : reverse(mods))
         {
             if(mod->bypass())
                 continue;
-            module_pm{mod, &prog, &trace}.run_pass(p);
+            if(not visited.insert(mod).second)
+                continue;
+            module_pm mpm{mod, &trace};
+            mpm.prog = &prog;
+            auto parents  = range(tree.equal_range(mod));
+            auto nparents = distance(parents);
+            if(nparents == 0)
+                mpm.common_parent = nullptr;
+            else if(nparents == 1)
+                mpm.common_parent = parents.begin()->second;
+            else
+                // Just set common parent to main module when there is multiple parents for now
+                // TODO: Compute the common parent
+                mpm.common_parent = prog.get_main_module();
+            mpm.run_pass(p);
         }
         run_pass(prog, p, trace);
     }
src/program.cpp
@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules()
     return result;
 }
 
+template <class Module, class Map>
+void generic_insert_module_tree(Module* pm, Map& m)
+{
+    for(auto* sm : pm->get_sub_modules(true))
+    {
+        m.insert(std::make_pair(sm, pm));
+        generic_insert_module_tree(sm, m);
+    }
+}
+
+std::unordered_multimap<module_ref, module_ref> program::get_module_tree()
+{
+    std::unordered_multimap<module_ref, module_ref> result;
+    generic_insert_module_tree(this->get_main_module(), result);
+    return result;
+}
+
 template <class Map, class T>
 bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
 {
src/shape.cpp
@@ -61,9 +61,7 @@ struct shape_impl
     {
         assert(t != shape::tuple_type);
         assert(m_lens.size() == m_strides.size());
-        // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
-        //        "At least one stride must be non-zero");
-        m_standard = this->elements() == this->element_space() and
+        m_standard = this->elements() == this->element_space() and not skips() and
                      std::is_sorted(m_strides.rbegin(), m_strides.rend());
     }
@@ -110,6 +108,15 @@ struct shape_impl
             m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
     }
 
+    // Does the shape skip over elements?
+    bool skips() const
+    {
+        assert(m_lens.size() == m_strides.size());
+        if(elements() == 1)
+            return false;
+        return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
+    }
+
     std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
 };
@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
 bool shape::packed() const
 {
-    return this->sub_shapes().empty() and this->elements() == this->element_space();
+    return this->sub_shapes().empty() and not impl->skips() and
+           this->elements() == this->element_space();
 }
 
 bool shape::transposed() const
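For illustration only (not part of this commit): a worked case the new skips() check is aimed at. With lens {3, 3} and strides {2, 2}, the touched offsets are 0, 2, 4, ..., 8, so elements() and element_space() are both 9 and the old packed() test accepted the layout even though no stride equals 1 and odd offsets are never addressed. The function name skips_example is arbitrary.

    #include <migraphx/shape.hpp>
    #include <cassert>

    // Hypothetical check against the internal shape class.
    void skips_example()
    {
        migraphx::shape s{migraphx::shape::float_type, {3, 3}, {2, 2}};
        assert(not s.packed()); // rejected by the skips()-aware packed()
    }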
@@ -285,10 +293,8 @@ bool shape::transposed() const
 bool shape::broadcasted() const
 {
     assert(this->lens().size() == this->strides().size());
-    return std::accumulate(this->strides().begin(),
-                           this->strides().end(),
-                           std::size_t{1},
-                           std::multiplies<std::size_t>()) == 0;
+    return std::any_of(
+        this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
 }
 
 bool shape::scalar() const
src/targets/gpu/CMakeLists.txt
@@ -163,7 +163,6 @@ add_library(migraphx_gpu
     convolution.cpp
     deconvolution.cpp
     device_name.cpp
-    eliminate_workspace.cpp
     elu.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
@@ -48,6 +48,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
 #include <cmath>
 #include <set>
@@ -701,6 +702,7 @@ struct miopen_fusion
         return args.back();
     }
 };
+MIGRAPHX_REGISTER_OP(miopen_fusion)
 
 struct miopen_conv_bias
 {
@@ -1005,9 +1007,43 @@ struct find_commutative_broadcast
 };
 
 } // namespace
 
+struct find_contiguous
+{
+    auto matcher() const { return match::name("gpu::contiguous"); }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+
+        m.replace_instruction(
+            ins,
+            make_op("gpu::precompile_op", {{"op", to_value(make_op("contiguous"))}}),
+            ins->inputs());
+    }
+};
+
+struct find_contiguous_pointwise
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins    = r.result;
+        auto pw     = ins->inputs().front();
+        auto alloc  = ins->inputs().back();
+        auto args   = pw->inputs();
+        args.back() = alloc;
+
+        m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
+    }
+};
+
 void fuse_ops::apply(module& m) const
 {
-    match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
+    match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
     run_passes(m, {dead_code_elimination{}});
     match::find_matches(m, find_triadd{});
     match::find_matches(m,
@@ -1029,6 +1065,7 @@ void fuse_ops::apply(module& m) const
                         find_gemm_add{},
                         find_gemm_pointwise{},
                         find_commutative_broadcast{});
+    match::find_matches(m, find_contiguous{});
 }
 
 } // namespace gpu
src/targets/gpu/hip.cpp
@@ -37,7 +37,6 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_REGISTER_OP(hip_allocate)
-MIGRAPHX_REGISTER_OP(hip_sync_device)
 MIGRAPHX_REGISTER_OP(hip_sync_stream)
 MIGRAPHX_REGISTER_OP(hip_copy_to_gpu)
 MIGRAPHX_REGISTER_OP(hip_copy_from_gpu)
src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp
deleted (100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_WORKSPACE_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_WORKSPACE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct eliminate_workspace
{
    std::string name() const { return "eliminate_workspace"; }
    void apply(module& m) const;
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/targets/gpu/include/migraphx/gpu/hip.hpp
@@ -59,12 +59,11 @@ argument get_preallocation(context& ctx, const std::string& id);
 struct hip_allocate
 {
     shape s;
     std::string tag{};
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.s, "shape"), f(self.tag, "tag"));
+        return pack(f(self.s, "shape"));
     }
     std::string name() const { return "hip::allocate"; }
@@ -79,42 +78,8 @@ struct hip_allocate
     }
 };
 
-struct hip_sync_device
-{
-    std::string tag{};
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.tag, "tag"));
-    }
-    std::string name() const { return "hip::sync_device"; }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        if(inputs.empty())
-            return {};
-        return inputs.front();
-    }
-    argument compute(context&, const shape&, const std::vector<argument>& args) const
-    {
-        gpu_sync();
-        if(args.empty())
-            return {};
-        return args.front();
-    }
-};
-
 struct hip_sync_stream
 {
     std::string tag{};
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack(f(self.tag, "tag"));
     }
     std::string name() const { return "hip::sync_stream"; }
     shape compute_shape(const std::vector<shape>& inputs) const
src/targets/gpu/jit/pointwise.cpp
@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
 struct pointwise_compiler : compiler<pointwise_compiler>
 {
-    std::vector<std::string> names() const { return {"pointwise"}; }
+    std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
 
     static std::size_t oversubscribe_if(bool b)
     {
@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
         return compile_hip_code_object(src, options);
     }
 
-    compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
-    {
-        assert(not ins->module_inputs().empty());
-        auto* pm = ins->module_inputs().front();
-        run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
-        cpp_generator g;
-        g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
-        g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
-        g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
-        g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
-        g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
-        g.add_point_op("less", "migraphx::abs(${0} < ${1})");
-        g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
-        g.add_point_op("not", "migraphx::abs(not ${0})");
-        // Add explicit conversions
-        g.fresult(
-            [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
-        auto name = g.create_function(
-            g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
-        std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
-        auto op_names      = get_op_names(*pm);
-        op_names.push_back("kernel");
-        auto op_name_string = join_strings(op_names, "_");
-        return replace(compile_op(
-            ctx,
-            to_shapes(ins->inputs()),
-            {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
-    }
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
+    {
+        if(op.name() == "contiguous")
+        {
+            return replace(compile_op(
+                ctx,
+                to_shapes(ins->inputs()),
+                {{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
+        }
+        else
+        {
+            assert(not ins->module_inputs().empty());
+            auto* pm = ins->module_inputs().front();
+            run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
+            cpp_generator g;
+            g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
+            g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
+            g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
+            g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
+            g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
+            g.add_point_op("less", "migraphx::abs(${0} < ${1})");
+            g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
+            g.add_point_op("not", "migraphx::abs(not ${0})");
+            // Add explicit conversions
+            g.fresult([](const shape& s) {
+                return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
+            });
+            auto name = g.create_function(
+                g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
+            std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
+            auto op_names      = get_op_names(*pm);
+            op_names.push_back("kernel");
+            auto op_name_string = join_strings(op_names, "_");
+            return replace(compile_op(
+                ctx,
+                to_shapes(ins->inputs()),
+                {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
+        }
+    }
 };
 
 } // namespace gpu
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
 {
     for(; first != last; ++first)
     {
-        init = op(std::move(init), *first);
+        init = op(static_cast<T&&>(init), *first);
     }
     return init;
 }
@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
     return d_first;
 }
 
+template <class InputIt, class OutputIt, class UnaryPredicate>
+constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred)
+{
+    for(; first != last; ++first)
+    {
+        if(pred(*first))
+        {
+            *d_first = *first;
+            ++d_first;
+        }
+    }
+    return d_first;
+}
+
 template <class Iterator, class Compare>
 constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
 {
@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
     return find_if(first, last, [&](const auto& x) { return x == value; });
 }
 
+template <class InputIt, class UnaryPredicate>
+constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
+{
+    return find_if(first, last, p) != last;
+}
+
+template <class InputIt, class UnaryPredicate>
+constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
+{
+    return find_if(first, last, p) == last;
+}
+
+template <class InputIt, class UnaryPredicate>
+constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
+{
+    return none_of(first, last, [=](auto&& x) { return not p(x); });
+}
+
 template <class Iterator1, class Iterator2>
 constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
 {
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
@@ -41,8 +41,15 @@ struct implicit_conversion_op
     template <index_int N, class U>
     constexpr operator vec<U, N>() const
     {
-        static_assert(vec_size<T>() == N, "Vector mismatch size");
-        return __builtin_convertvector(x, vec<U, N>);
+        if constexpr(vec_size<T>() == 0)
+        {
+            return x;
+        }
+        else
+        {
+            static_assert(vec_size<T>() == N, "Vector mismatch size");
+            return __builtin_convertvector(x, vec<U, N>);
+        }
     }
 
     template <class U>