Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
407acb7d
Commit
407acb7d
authored
May 17, 2022
by
Paul
Browse files
Jit contiguous
parent
a27dd28c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
26 deletions
+41
-26
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+31
-23
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+9
-2
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-1
No files found.
src/targets/gpu/jit/pointwise.cpp
View file @
407acb7d
...
@@ -41,7 +41,7 @@ __global__ void kernel(${params})
...
@@ -41,7 +41,7 @@ __global__ void kernel(${params})
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
static
std
::
size_t
oversubscribe_if
(
bool
b
)
static
std
::
size_t
oversubscribe_if
(
bool
b
)
{
{
...
@@ -146,29 +146,37 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -146,29 +146,37 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
assert
(
not
ins
->
module_inputs
().
empty
());
if
(
op
.
name
()
==
"contiguous"
)
auto
*
pm
=
ins
->
module_inputs
().
front
();
{
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
return
replace
(
cpp_generator
g
;
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
"[](auto x) { return x; }"
}}));
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
}
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
else
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
{
g
.
add_point_op
(
"sign"
,
assert
(
not
ins
->
module_inputs
().
empty
());
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
auto
*
pm
=
ins
->
module_inputs
().
front
();
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
cpp_generator
g
;
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
// Add explict conversions
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
fresult
(
g
.
add_point_op
(
"sign"
,
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
auto
name
=
g
.
create_function
(
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
return
replace
(
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()}}));
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()}}));
}
}
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
407acb7d
...
@@ -18,8 +18,15 @@ struct implicit_conversion_op
...
@@ -18,8 +18,15 @@ struct implicit_conversion_op
template
<
index_int
N
,
class
U
>
template
<
index_int
N
,
class
U
>
constexpr
operator
vec
<
U
,
N
>
()
const
constexpr
operator
vec
<
U
,
N
>
()
const
{
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
if
constexpr
(
vec_size
<
T
>
()
==
0
)
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
{
return
x
;
}
else
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
}
}
}
template
<
class
U
>
template
<
class
U
>
...
...
src/targets/gpu/lowering.cpp
View file @
407acb7d
...
@@ -130,7 +130,7 @@ struct miopen_apply
...
@@ -130,7 +130,7 @@ struct miopen_apply
add_generic_op
(
"atan"
);
add_generic_op
(
"atan"
);
add_generic_op
(
"atanh"
);
add_generic_op
(
"atanh"
);
add_generic_op
(
"ceil"
);
add_generic_op
(
"ceil"
);
add_generic_op
(
"contiguous"
);
//
add_generic_op("contiguous");
add_generic_op
(
"cos"
);
add_generic_op
(
"cos"
);
add_generic_op
(
"cosh"
);
add_generic_op
(
"cosh"
);
add_generic_op
(
"div"
);
add_generic_op
(
"div"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment