Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
818e6e53
Unverified
Commit
818e6e53
authored
Jun 28, 2022
by
Paul Fultz II
Committed by
GitHub
Jun 28, 2022
Browse files
Merge branch 'develop' into mlir-c
parents
d716c516
cb18b0b5
Changes
51
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
295 additions
and
146 deletions
+295
-146
src/api/api.cpp
src/api/api.cpp
+52
-1
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+14
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+24
-4
src/api/migraphx.py
src/api/migraphx.py
+9
-0
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+1
-1
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+1
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+2
-0
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+9
-3
src/module.cpp
src/module.cpp
+7
-4
src/pass_manager.cpp
src/pass_manager.cpp
+25
-9
src/program.cpp
src/program.cpp
+17
-0
src/shape.cpp
src/shape.cpp
+14
-8
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+0
-1
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+38
-1
src/targets/gpu/hip.cpp
src/targets/gpu/hip.cpp
+0
-1
src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp
src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp
+0
-46
src/targets/gpu/include/migraphx/gpu/hip.hpp
src/targets/gpu/include/migraphx/gpu/hip.hpp
+1
-36
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+39
-28
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+33
-1
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+9
-2
No files found.
src/api/api.cpp
View file @
818e6e53
...
@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
...
@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
migraphx
::
instruction_ref
add_allocation
(
module
&
m
,
const
migraphx
::
shape
&
s
)
{
return
m
.
add_instruction
(
migraphx
::
make_op
(
"allocate"
,
{{
"shape"
,
migraphx
::
to_value
(
s
)}}),
{});
}
struct
experimental_custom_op
struct
experimental_custom_op
{
{
std
::
string
name
;
std
::
string
name
;
...
@@ -260,7 +265,12 @@ struct custom_operation
...
@@ -260,7 +265,12 @@ struct custom_operation
return
op
.
compute_shape
(
std
::
move
(
inputs
));
return
op
.
compute_shape
(
std
::
move
(
inputs
));
}
}
argument
compute
(
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPHX_THROW
(
"Not computable"
);
}
// TODO: Compute method with module_args
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
output_shape
,
std
::
vector
<
argument
>
inputs
)
const
{
return
op
.
compute
(
std
::
move
(
ctx
),
std
::
move
(
output_shape
),
std
::
move
(
inputs
));
}
};
};
template
<
class
CustomOp
>
template
<
class
CustomOp
>
...
@@ -577,6 +587,24 @@ struct migraphx_experimental_custom_op
...
@@ -577,6 +587,24 @@ struct migraphx_experimental_custom_op
manage_generic_ptr
<
migraphx_experimental_custom_op_copy
,
migraphx_experimental_custom_op_delete
>
manage_generic_ptr
<
migraphx_experimental_custom_op_copy
,
migraphx_experimental_custom_op_delete
>
object_ptr
=
nullptr
;
object_ptr
=
nullptr
;
migraphx
::
experimental_custom_op
xobject
;
migraphx
::
experimental_custom_op
xobject
;
migraphx_experimental_custom_op_compute
compute_f
=
nullptr
;
migraphx
::
argument
compute
(
migraphx
::
context
ctx
,
migraphx
::
shape
output
,
std
::
vector
<
migraphx
::
argument
>
inputs
)
const
{
std
::
remove_pointer_t
<
migraphx_argument_t
>
out
;
if
(
compute_f
==
nullptr
)
throw
std
::
runtime_error
(
"compute function is missing."
);
auto
api_error_result
=
compute_f
(
&
out
,
object_ptr
.
data
,
object_cast
<
migraphx_context_t
>
(
&
(
ctx
)),
object_cast
<
migraphx_shape_t
>
(
&
(
output
)),
object_cast
<
migraphx_arguments_t
>
(
&
(
inputs
)));
if
(
api_error_result
!=
migraphx_status_success
)
throw
std
::
runtime_error
(
"Error in compute."
);
return
(
&
out
)
->
object
;
}
migraphx_experimental_custom_op_compute_shape
compute_shape_f
=
nullptr
;
migraphx_experimental_custom_op_compute_shape
compute_shape_f
=
nullptr
;
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
{
{
...
@@ -1141,6 +1169,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
...
@@ -1141,6 +1169,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_module_add_allocation
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
s
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
s
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter s: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
migraphx
::
add_allocation
((
module
->
object
),
(
s
->
object
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
)
extern
"C"
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
)
{
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
program
));
});
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
destroy
((
program
));
});
...
@@ -1772,6 +1815,14 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
...
@@ -1772,6 +1815,14 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_experimental_custom_op_set_compute
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute
input
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
(
obj
)
->
compute_f
=
(
input
);
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
extern
"C"
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
)
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
)
{
{
...
...
src/api/include/migraphx/migraphx.h
View file @
818e6e53
...
@@ -129,6 +129,12 @@ typedef const struct migraphx_context* const_migraphx_context_t;
...
@@ -129,6 +129,12 @@ typedef const struct migraphx_context* const_migraphx_context_t;
typedef
struct
migraphx_experimental_custom_op
*
migraphx_experimental_custom_op_t
;
typedef
struct
migraphx_experimental_custom_op
*
migraphx_experimental_custom_op_t
;
typedef
const
struct
migraphx_experimental_custom_op
*
const_migraphx_experimental_custom_op_t
;
typedef
const
struct
migraphx_experimental_custom_op
*
const_migraphx_experimental_custom_op_t
;
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute
)(
migraphx_argument_t
out
,
void
*
obj
,
migraphx_context_t
ctx
,
migraphx_shape_t
output
,
migraphx_arguments_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute_shape
)(
migraphx_shape_t
out
,
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute_shape
)(
migraphx_shape_t
out
,
void
*
obj
,
void
*
obj
,
migraphx_shapes_t
inputs
);
migraphx_shapes_t
inputs
);
...
@@ -295,6 +301,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
...
@@ -295,6 +301,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t
module
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_allocation
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
s
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
...
@@ -477,6 +487,10 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
...
@@ -477,6 +487,10 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
migraphx_experimental_custom_op_delete
d
,
migraphx_experimental_custom_op_delete
d
,
const
char
*
name
);
const
char
*
name
);
migraphx_status
migraphx_experimental_custom_op_set_compute
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute
input
);
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
...
...
src/api/include/migraphx/migraphx.hpp
View file @
818e6e53
...
@@ -401,11 +401,14 @@ struct interface_base : Base
...
@@ -401,11 +401,14 @@ struct interface_base : Base
return
x
;
return
x
;
}
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
template
<
class
T
>
template
<
class
T
>
auto
auto_convert_param
(
rank
<
1
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
})
auto
auto_convert_param
(
rank
<
1
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
})
{
{
return
as_handle
<
T
>
{
x
};
return
as_handle
<
T
>
{
x
};
}
}
#pragma GCC diagnostic pop
template
<
class
T
>
template
<
class
T
>
auto
auto_convert_param
(
rank
<
2
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
,
borrow
{}})
auto
auto_convert_param
(
rank
<
2
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
,
borrow
{}})
...
@@ -565,6 +568,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -565,6 +568,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return
pout
;
return
pout
;
}
}
template
<
typename
T
>
std
::
vector
<
T
>
as_vector
()
const
{
size_t
vector_len
=
this
->
get_shape
().
bytes
()
/
sizeof
(
T
);
T
*
buffer_ptr
=
reinterpret_cast
<
T
*>
(
this
->
data
());
return
{
buffer_ptr
,
buffer_ptr
+
vector_len
};
}
/// Generate an argument using random data
/// Generate an argument using random data
static
argument
generate
(
shape
ps
,
size_t
pseed
=
0
)
static
argument
generate
(
shape
ps
,
size_t
pseed
=
0
)
{
{
...
@@ -802,13 +813,20 @@ struct module
...
@@ -802,13 +813,20 @@ struct module
return
instruction
(
ret_ins
,
own
{});
return
instruction
(
ret_ins
,
own
{});
}
}
instruction
add_allocation
(
const
migraphx
::
shape
&
s
)
{
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_allocation
,
&
ret_ins
,
mm
.
get
(),
s
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
migraphx_module_t
get_handle_ptr
()
const
{
return
mm
.
get
();
}
private:
private:
std
::
shared_ptr
<
migraphx_module
>
mm
;
std
::
shared_ptr
<
migraphx_module
>
mm
;
};
};
struct
context
struct
context
:
handle_lookup
<
context
,
migraphx_context
>
{
{
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
context
(
migraphx_context
*
p
,
borrow
)
:
ctx
(
std
::
shared_ptr
<
migraphx_context
*>
(),
p
)
{}
...
@@ -1177,9 +1195,10 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
...
@@ -1177,9 +1195,10 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct
experimental_custom_op_base
struct
experimental_custom_op_base
{
{
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
argument
compute
(
context
ctx
,
shape
output
,
arguments
inputs
)
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
};
};
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
...
@@ -1189,6 +1208,7 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
...
@@ -1189,6 +1208,7 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
{
{
this
->
make_interface
(
&
migraphx_experimental_custom_op_create
,
obj
,
obj
.
name
().
c_str
());
this
->
make_interface
(
&
migraphx_experimental_custom_op_create
,
obj
,
obj
.
name
().
c_str
());
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute
);
}
}
void
register_op
()
{
call
(
&
migraphx_experimental_custom_op_register
,
this
->
get_handle_ptr
());
}
void
register_op
()
{
call
(
&
migraphx_experimental_custom_op_register
,
this
->
get_handle_ptr
());
}
...
...
src/api/migraphx.py
View file @
818e6e53
...
@@ -244,6 +244,10 @@ def module(h):
...
@@ -244,6 +244,10 @@ def module(h):
h
.
method
(
'add_return'
,
h
.
method
(
'add_return'
,
api
.
params
(
args
=
'std::vector<migraphx::instruction_ref>'
),
api
.
params
(
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_allocation'
,
api
.
params
(
s
=
'const migraphx::shape&'
),
invoke
=
'migraphx::add_allocation($@)'
,
returns
=
'migraphx::instruction_ref'
)
@
auto_handle
()
@
auto_handle
()
...
@@ -436,6 +440,11 @@ def context(h):
...
@@ -436,6 +440,11 @@ def context(h):
'migraphx::experimental_custom_op'
)
'migraphx::experimental_custom_op'
)
def
experimental_custom_op
(
h
):
def
experimental_custom_op
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'const char*'
))
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'const char*'
))
h
.
virtual
(
'compute'
,
api
.
params
(
ctx
=
'migraphx::context'
,
output
=
'migraphx::shape'
,
inputs
=
'std::vector<migraphx::argument>'
),
returns
=
'migraphx::argument'
)
h
.
virtual
(
'compute_shape'
,
h
.
virtual
(
'compute_shape'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'migraphx::shape'
)
returns
=
'migraphx::shape'
)
...
...
src/include/migraphx/module.hpp
View file @
818e6e53
...
@@ -183,7 +183,7 @@ struct module
...
@@ -183,7 +183,7 @@ struct module
void
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
;
void
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
;
std
::
vector
<
module_ref
>
get_sub_modules
()
const
;
std
::
vector
<
module_ref
>
get_sub_modules
(
bool
shallow
=
false
)
const
;
module
&
sort
();
module
&
sort
();
ins_dep_map
calc_implicit_deps
()
const
;
ins_dep_map
calc_implicit_deps
()
const
;
...
...
src/include/migraphx/pass_manager.hpp
View file @
818e6e53
...
@@ -38,6 +38,7 @@ struct module_pass_manager
...
@@ -38,6 +38,7 @@ struct module_pass_manager
module_pass_manager
(
const
module_pass_manager
&
)
=
delete
;
module_pass_manager
(
const
module_pass_manager
&
)
=
delete
;
virtual
module
&
get_module
()
=
0
;
virtual
module
&
get_module
()
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
module
*
get_common_parent
()
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
protected:
protected:
...
...
src/include/migraphx/program.hpp
View file @
818e6e53
...
@@ -132,6 +132,8 @@ struct program
...
@@ -132,6 +132,8 @@ struct program
std
::
vector
<
const
module
*>
get_modules
()
const
;
std
::
vector
<
const
module
*>
get_modules
()
const
;
std
::
vector
<
module
*>
get_modules
();
std
::
vector
<
module
*>
get_modules
();
std
::
unordered_multimap
<
module_ref
,
module_ref
>
get_module_tree
();
void
remove_module
(
const
std
::
string
&
name
);
void
remove_module
(
const
std
::
string
&
name
);
void
remove_unused_modules
();
void
remove_unused_modules
();
...
...
src/include/migraphx/ranges.hpp
View file @
818e6e53
...
@@ -216,10 +216,16 @@ void replace(Range&& r, const T& old, const T& new_x)
...
@@ -216,10 +216,16 @@ void replace(Range&& r, const T& old, const T& new_x)
std
::
replace
(
r
.
begin
(),
r
.
end
(),
old
,
new_x
);
std
::
replace
(
r
.
begin
(),
r
.
end
(),
old
,
new_x
);
}
}
template
<
class
R1
,
class
R2
>
template
<
class
R1
,
class
R2
,
class
...
Predicate
>
bool
equal
(
R1
&&
r1
,
R2
&&
r2
)
bool
equal
(
R1
&&
r1
,
R2
&&
r2
,
Predicate
...
pred
)
{
{
return
std
::
equal
(
r1
.
begin
(),
r1
.
end
(),
r2
.
begin
(),
r2
.
end
());
return
std
::
equal
(
r1
.
begin
(),
r1
.
end
(),
r2
.
begin
(),
r2
.
end
(),
pred
...);
}
template
<
class
Range
>
auto
distance
(
Range
&&
r
)
{
return
std
::
distance
(
r
.
begin
(),
r
.
end
());
}
}
template
<
class
R
>
template
<
class
R
>
...
...
src/module.cpp
View file @
818e6e53
...
@@ -821,17 +821,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
...
@@ -821,17 +821,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
});
});
}
}
std
::
vector
<
module_ref
>
module
::
get_sub_modules
()
const
std
::
vector
<
module_ref
>
module
::
get_sub_modules
(
bool
shallow
)
const
{
{
std
::
vector
<
module_ref
>
vec_modules
;
std
::
vector
<
module_ref
>
vec_modules
;
for
(
auto
ins
:
iterator_for
(
*
this
))
for
(
auto
ins
:
iterator_for
(
*
this
))
{
{
const
auto
&
mod_args
=
ins
->
module_inputs
();
const
auto
&
mod_args
=
ins
->
module_inputs
();
vec_modules
.
insert
(
vec_modules
.
end
(),
mod_args
.
begin
(),
mod_args
.
end
());
vec_modules
.
insert
(
vec_modules
.
end
(),
mod_args
.
begin
(),
mod_args
.
end
());
for
(
const
auto
&
smod
:
mod_args
)
if
(
not
shallow
)
{
{
auto
sub_mods
=
smod
->
get_sub_modules
();
for
(
const
auto
&
smod
:
mod_args
)
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_mods
.
begin
(),
sub_mods
.
end
());
{
auto
sub_mods
=
smod
->
get_sub_modules
();
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_mods
.
begin
(),
sub_mods
.
end
());
}
}
}
}
}
...
...
src/pass_manager.cpp
View file @
818e6e53
...
@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
...
@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct
module_pm
:
module_pass_manager
struct
module_pm
:
module_pass_manager
{
{
module
*
mod
;
module
*
mod
=
nullptr
;
program
*
prog
;
tracer
*
t
=
nullptr
;
tracer
*
t
;
module
*
common_parent
=
nullptr
;
program
*
prog
=
nullptr
;
module_pm
(
module
*
pmod
=
nullptr
,
program
*
pprog
=
nullptr
,
tracer
*
pt
=
nullptr
)
module_pm
(
module
*
pmod
=
nullptr
,
tracer
*
pt
=
nullptr
)
:
mod
(
pmod
),
t
(
pt
)
{}
:
mod
(
pmod
),
prog
(
pprog
),
t
(
pt
)
{
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
trace
(
Ts
&&
...
xs
)
const
void
trace
(
Ts
&&
...
xs
)
const
...
@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
...
@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
assert
(
prog
);
assert
(
prog
);
return
prog
->
create_module
(
name
);
return
prog
->
create_module
(
name
);
}
}
virtual
module
*
get_common_parent
()
override
{
return
common_parent
;
}
virtual
void
run_pass
(
const
pass
&
p
)
override
virtual
void
run_pass
(
const
pass
&
p
)
override
{
{
assert
(
mod
);
assert
(
mod
);
...
@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
...
@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace
=
tracer
{
std
::
cout
};
trace
=
tracer
{
std
::
cout
};
for
(
const
auto
&
p
:
passes
)
for
(
const
auto
&
p
:
passes
)
{
{
module_pm
{
&
mod
,
nullptr
,
&
trace
}.
run_pass
(
p
);
module_pm
{
&
mod
,
&
trace
}.
run_pass
(
p
);
}
}
}
}
...
@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
...
@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
if
(
enabled
(
MIGRAPHX_TRACE_PASSES
{}))
trace
=
tracer
{
std
::
cout
};
trace
=
tracer
{
std
::
cout
};
std
::
unordered_set
<
module_ref
>
visited
;
for
(
const
auto
&
p
:
passes
)
for
(
const
auto
&
p
:
passes
)
{
{
auto
mods
=
prog
.
get_modules
();
auto
mods
=
prog
.
get_modules
();
auto
tree
=
prog
.
get_module_tree
();
visited
.
clear
();
for
(
const
auto
&
mod
:
reverse
(
mods
))
for
(
const
auto
&
mod
:
reverse
(
mods
))
{
{
if
(
mod
->
bypass
())
if
(
mod
->
bypass
())
continue
;
continue
;
module_pm
{
mod
,
&
prog
,
&
trace
}.
run_pass
(
p
);
if
(
not
visited
.
insert
(
mod
).
second
)
continue
;
module_pm
mpm
{
mod
,
&
trace
};
mpm
.
prog
=
&
prog
;
auto
parents
=
range
(
tree
.
equal_range
(
mod
));
auto
nparents
=
distance
(
parents
);
if
(
nparents
==
0
)
mpm
.
common_parent
=
nullptr
;
else
if
(
nparents
==
1
)
mpm
.
common_parent
=
parents
.
begin
()
->
second
;
else
// Just set common parent to main module when there is muliple parents for now
// TODO: Compute the common parent
mpm
.
common_parent
=
prog
.
get_main_module
();
mpm
.
run_pass
(
p
);
}
}
run_pass
(
prog
,
p
,
trace
);
run_pass
(
prog
,
p
,
trace
);
}
}
...
...
src/program.cpp
View file @
818e6e53
...
@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules()
...
@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules()
return
result
;
return
result
;
}
}
template
<
class
Module
,
class
Map
>
void
generic_insert_module_tree
(
Module
*
pm
,
Map
&
m
)
{
for
(
auto
*
sm
:
pm
->
get_sub_modules
(
true
))
{
m
.
insert
(
std
::
make_pair
(
sm
,
pm
));
generic_insert_module_tree
(
sm
,
m
);
}
}
std
::
unordered_multimap
<
module_ref
,
module_ref
>
program
::
get_module_tree
()
{
std
::
unordered_multimap
<
module_ref
,
module_ref
>
result
;
generic_insert_module_tree
(
this
->
get_main_module
(),
result
);
return
result
;
}
template
<
class
Map
,
class
T
>
template
<
class
Map
,
class
T
>
bool
is_unused_module
(
Map
&
m
,
const
std
::
vector
<
T
*>&
mods
,
const
std
::
string
&
name
)
bool
is_unused_module
(
Map
&
m
,
const
std
::
vector
<
T
*>&
mods
,
const
std
::
string
&
name
)
{
{
...
...
src/shape.cpp
View file @
818e6e53
...
@@ -61,9 +61,7 @@ struct shape_impl
...
@@ -61,9 +61,7 @@ struct shape_impl
{
{
assert
(
t
!=
shape
::
tuple_type
);
assert
(
t
!=
shape
::
tuple_type
);
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
m_standard
=
this
->
elements
()
==
this
->
element_space
()
and
not
skips
()
and
// "At least one stride must be non-zero");
m_standard
=
this
->
elements
()
==
this
->
element_space
()
and
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
}
...
@@ -110,6 +108,15 @@ struct shape_impl
...
@@ -110,6 +108,15 @@ struct shape_impl
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
}
// Does the shape skip over elements?
bool
skips
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
elements
()
==
1
)
return
false
;
return
std
::
none_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
==
1
;
});
}
std
::
shared_ptr
<
shape_impl
>
copy
()
const
{
return
std
::
make_shared
<
shape_impl
>
(
*
this
);
}
std
::
shared_ptr
<
shape_impl
>
copy
()
const
{
return
std
::
make_shared
<
shape_impl
>
(
*
this
);
}
};
};
...
@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
...
@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool
shape
::
packed
()
const
bool
shape
::
packed
()
const
{
{
return
this
->
sub_shapes
().
empty
()
and
this
->
elements
()
==
this
->
element_space
();
return
this
->
sub_shapes
().
empty
()
and
not
impl
->
skips
()
and
this
->
elements
()
==
this
->
element_space
();
}
}
bool
shape
::
transposed
()
const
bool
shape
::
transposed
()
const
...
@@ -285,10 +293,8 @@ bool shape::transposed() const
...
@@ -285,10 +293,8 @@ bool shape::transposed() const
bool
shape
::
broadcasted
()
const
bool
shape
::
broadcasted
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
accumulate
(
this
->
strides
().
begin
(),
return
std
::
any_of
(
this
->
strides
().
end
(),
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
[](
auto
x
)
{
return
x
==
0
;
});
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
}
}
bool
shape
::
scalar
()
const
bool
shape
::
scalar
()
const
...
...
src/targets/gpu/CMakeLists.txt
View file @
818e6e53
...
@@ -163,7 +163,6 @@ add_library(migraphx_gpu
...
@@ -163,7 +163,6 @@ add_library(migraphx_gpu
convolution.cpp
convolution.cpp
deconvolution.cpp
deconvolution.cpp
device_name.cpp
device_name.cpp
eliminate_workspace.cpp
elu.cpp
elu.cpp
fuse_mlir.cpp
fuse_mlir.cpp
fuse_ops.cpp
fuse_ops.cpp
...
...
src/targets/gpu/fuse_ops.cpp
View file @
818e6e53
...
@@ -48,6 +48,7 @@
...
@@ -48,6 +48,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/clip.hpp>
#include <cmath>
#include <cmath>
#include <set>
#include <set>
...
@@ -701,6 +702,7 @@ struct miopen_fusion
...
@@ -701,6 +702,7 @@ struct miopen_fusion
return
args
.
back
();
return
args
.
back
();
}
}
};
};
MIGRAPHX_REGISTER_OP
(
miopen_fusion
)
struct
miopen_conv_bias
struct
miopen_conv_bias
{
{
...
@@ -1005,9 +1007,43 @@ struct find_commutative_broadcast
...
@@ -1005,9 +1007,43 @@ struct find_commutative_broadcast
};
};
}
// namespace
}
// namespace
struct
find_contiguous
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::precompile_op"
,
{{
"op"
,
to_value
(
make_op
(
"contiguous"
))}}),
ins
->
inputs
());
}
};
struct
find_contiguous_pointwise
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
precompile_name
(
"pointwise"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
pw
=
ins
->
inputs
().
front
();
auto
alloc
=
ins
->
inputs
().
back
();
auto
args
=
pw
->
inputs
();
args
.
back
()
=
alloc
;
m
.
replace_instruction
(
ins
,
pw
->
get_operator
(),
args
,
pw
->
module_inputs
());
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
void
fuse_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_gelu
{},
find_gelu_new
{
fast_math
});
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
run_passes
(
m
,
{
dead_code_elimination
{}});
run_passes
(
m
,
{
dead_code_elimination
{}});
match
::
find_matches
(
m
,
find_triadd
{});
match
::
find_matches
(
m
,
find_triadd
{});
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
...
@@ -1029,6 +1065,7 @@ void fuse_ops::apply(module& m) const
...
@@ -1029,6 +1065,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add
{},
find_gemm_add
{},
find_gemm_pointwise
{},
find_gemm_pointwise
{},
find_commutative_broadcast
{});
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/hip.cpp
View file @
818e6e53
...
@@ -37,7 +37,6 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -37,7 +37,6 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_REGISTER_OP
(
hip_allocate
)
MIGRAPHX_REGISTER_OP
(
hip_allocate
)
MIGRAPHX_REGISTER_OP
(
hip_sync_device
)
MIGRAPHX_REGISTER_OP
(
hip_sync_stream
)
MIGRAPHX_REGISTER_OP
(
hip_sync_stream
)
MIGRAPHX_REGISTER_OP
(
hip_copy_to_gpu
)
MIGRAPHX_REGISTER_OP
(
hip_copy_to_gpu
)
MIGRAPHX_REGISTER_OP
(
hip_copy_from_gpu
)
MIGRAPHX_REGISTER_OP
(
hip_copy_from_gpu
)
...
...
src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp
deleted
100644 → 0
View file @
d716c516
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_WORKSPACE_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_WORKSPACE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
namespace
gpu
{
struct
eliminate_workspace
{
std
::
string
name
()
const
{
return
"eliminate_workspace"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/include/migraphx/gpu/hip.hpp
View file @
818e6e53
...
@@ -59,12 +59,11 @@ argument get_preallocation(context& ctx, const std::string& id);
...
@@ -59,12 +59,11 @@ argument get_preallocation(context& ctx, const std::string& id);
struct
hip_allocate
struct
hip_allocate
{
{
shape
s
;
shape
s
;
std
::
string
tag
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
s
,
"shape"
)
,
f
(
self
.
tag
,
"tag"
)
);
return
pack
(
f
(
self
.
s
,
"shape"
));
}
}
std
::
string
name
()
const
{
return
"hip::allocate"
;
}
std
::
string
name
()
const
{
return
"hip::allocate"
;
}
...
@@ -79,42 +78,8 @@ struct hip_allocate
...
@@ -79,42 +78,8 @@ struct hip_allocate
}
}
};
};
struct
hip_sync_device
{
std
::
string
tag
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
tag
,
"tag"
));
}
std
::
string
name
()
const
{
return
"hip::sync_device"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
if
(
inputs
.
empty
())
return
{};
return
inputs
.
front
();
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
gpu_sync
();
if
(
args
.
empty
())
return
{};
return
args
.
front
();
}
};
struct
hip_sync_stream
struct
hip_sync_stream
{
{
std
::
string
tag
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
tag
,
"tag"
));
}
std
::
string
name
()
const
{
return
"hip::sync_stream"
;
}
std
::
string
name
()
const
{
return
"hip::sync_stream"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
818e6e53
...
@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
...
@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
static
std
::
size_t
oversubscribe_if
(
bool
b
)
static
std
::
size_t
oversubscribe_if
(
bool
b
)
{
{
...
@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
assert
(
not
ins
->
module_inputs
().
empty
());
if
(
op
.
name
()
==
"contiguous"
)
auto
*
pm
=
ins
->
module_inputs
().
front
();
{
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
return
replace
(
compile_op
(
cpp_generator
g
;
ctx
,
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
to_shapes
(
ins
->
inputs
()),
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
{{
"lambda"
,
"[](auto x) { return x; }"
},
{
"kernel"
,
"contiguous_kernel"
}}));
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
}
g
.
add_point_op
(
"sign"
,
else
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
{
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
assert
(
not
ins
->
module_inputs
().
empty
());
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
auto
*
pm
=
ins
->
module_inputs
().
front
();
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
cpp_generator
g
;
// Add explict conversions
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
fresult
(
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
auto
name
=
g
.
create_function
(
g
.
add_point_op
(
"sign"
,
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
auto
op_names
=
get_op_names
(
*
pm
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
op_names
.
push_back
(
"kernel"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
return
replace
(
// Add explict conversions
compile_op
(
ctx
,
g
.
fresult
([](
const
shape
&
s
)
{
to_shapes
(
ins
->
inputs
()),
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
818e6e53
...
@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
...
@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
{
for
(;
first
!=
last
;
++
first
)
for
(;
first
!=
last
;
++
first
)
{
{
init
=
op
(
st
d
::
move
(
init
),
*
first
);
init
=
op
(
st
atic_cast
<
T
&&>
(
init
),
*
first
);
}
}
return
init
;
return
init
;
}
}
...
@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
...
@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
return
d_first
;
return
d_first
;
}
}
template
<
class
InputIt
,
class
OutputIt
,
class
UnaryPredicate
>
constexpr
OutputIt
copy_if
(
InputIt
first
,
InputIt
last
,
OutputIt
d_first
,
UnaryPredicate
pred
)
{
for
(;
first
!=
last
;
++
first
)
{
if
(
pred
(
*
first
))
{
*
d_first
=
*
first
;
++
d_first
;
}
}
return
d_first
;
}
template
<
class
Iterator
,
class
Compare
>
template
<
class
Iterator
,
class
Compare
>
constexpr
Iterator
is_sorted_until
(
Iterator
first
,
Iterator
last
,
Compare
comp
)
constexpr
Iterator
is_sorted_until
(
Iterator
first
,
Iterator
last
,
Compare
comp
)
{
{
...
@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
...
@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
return
find_if
(
first
,
last
,
[
&
](
const
auto
&
x
)
{
return
x
==
value
;
});
return
find_if
(
first
,
last
,
[
&
](
const
auto
&
x
)
{
return
x
==
value
;
});
}
}
template
<
class
InputIt
,
class
UnaryPredicate
>
constexpr
bool
any_of
(
InputIt
first
,
InputIt
last
,
UnaryPredicate
p
)
{
return
find_if
(
first
,
last
,
p
)
!=
last
;
}
template
<
class
InputIt
,
class
UnaryPredicate
>
constexpr
bool
none_of
(
InputIt
first
,
InputIt
last
,
UnaryPredicate
p
)
{
return
find_if
(
first
,
last
,
p
)
==
last
;
}
template
<
class
InputIt
,
class
UnaryPredicate
>
constexpr
bool
all_of
(
InputIt
first
,
InputIt
last
,
UnaryPredicate
p
)
{
return
none_of
(
first
,
last
,
[
=
](
auto
&&
x
)
{
return
not
p
(
x
);
});
}
template
<
class
Iterator1
,
class
Iterator2
>
template
<
class
Iterator1
,
class
Iterator2
>
constexpr
Iterator1
search
(
Iterator1
first
,
Iterator1
last
,
Iterator2
s_first
,
Iterator2
s_last
)
constexpr
Iterator1
search
(
Iterator1
first
,
Iterator1
last
,
Iterator2
s_first
,
Iterator2
s_last
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
818e6e53
...
@@ -41,8 +41,15 @@ struct implicit_conversion_op
...
@@ -41,8 +41,15 @@ struct implicit_conversion_op
template
<
index_int
N
,
class
U
>
template
<
index_int
N
,
class
U
>
constexpr
operator
vec
<
U
,
N
>
()
const
constexpr
operator
vec
<
U
,
N
>
()
const
{
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
if
constexpr
(
vec_size
<
T
>
()
==
0
)
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
{
return
x
;
}
else
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
}
}
}
template
<
class
U
>
template
<
class
U
>
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment