Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5616894f
Commit
5616894f
authored
May 20, 2022
by
Paul
Browse files
Merge
parents
835cc1e2
4a312201
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
9 deletions
+33
-9
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+7
-1
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+26
-8
No files found.
src/include/migraphx/matcher.hpp
View file @
5616894f
...
@@ -754,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
...
@@ -754,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
}
}
template
<
class
...
Ms
>
auto
skip_broadcasts_converts
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
T
>
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
{
return
skip_broadcasts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts
_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
return
false
;
auto
l
=
ins
->
get_literal
();
auto
l
=
ins
->
get_literal
();
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
5616894f
...
@@ -27,7 +27,7 @@ namespace migraphx {
...
@@ -27,7 +27,7 @@ namespace migraphx {
${preamble}
${preamble}
extern "C" {
extern "C" {
__global__ void kernel(${params})
__global__ void
${
kernel
}
(${params})
{
{
auto idx = make_index();
auto idx = make_index();
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
...
@@ -39,6 +39,18 @@ __global__ void kernel(${params})
...
@@ -39,6 +39,18 @@ __global__ void kernel(${params})
)__migraphx__"
;
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
...
@@ -131,12 +143,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -131,12 +143,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
is_preloading
=
auto
is_preloading
=
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"kernel"
);
options
.
set_launch_params
(
v
,
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec_size
,
options
.
output
.
elements
()
/
vec_size
,
oversubscribe_if
(
not
is_preloading
)));
oversubscribe_if
(
not
is_preloading
)));
auto
src
=
interpolate_string
(
pointwise_kernel
,
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
...
@@ -151,7 +165,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -151,7 +165,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
if
(
op
.
name
()
==
"contiguous"
)
if
(
op
.
name
()
==
"contiguous"
)
{
{
return
replace
(
compile_op
(
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
"[](auto x) { return x; }"
}}));
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
"[](auto x) { return x; }"
}
,
{
"kernel"
,
"contiguous_kernel"
}
}));
}
}
else
else
{
{
...
@@ -169,14 +183,18 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -169,14 +183,18 @@ struct pointwise_compiler : compiler<pointwise_compiler>
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
// Add explict conversions
g
.
fresult
([](
const
shape
&
s
)
{
g
.
fresult
(
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
});
auto
name
=
g
.
create_function
(
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
return
replace
(
compile_op
(
auto
op_names
=
get_op_names
(
*
pm
);
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()}}));
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment