Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
faefeef9
Unverified
Commit
faefeef9
authored
May 25, 2022
by
Charlie Lin
Committed by
GitHub
May 25, 2022
Browse files
Merge branch 'develop' into dyn_shape_update
parents
97a40ac3
bf0a4713
Changes
94
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
366 additions
and
181 deletions
+366
-181
src/targets/gpu/include/migraphx/gpu/schedule_model.hpp
src/targets/gpu/include/migraphx/gpu/schedule_model.hpp
+3
-3
src/targets/gpu/include/migraphx/gpu/sync_device.hpp
src/targets/gpu/include/migraphx/gpu/sync_device.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/write_literals.hpp
src/targets/gpu/include/migraphx/gpu/write_literals.hpp
+1
-1
src/targets/gpu/jit/gathernd.cpp
src/targets/gpu/jit/gathernd.cpp
+1
-1
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+107
-21
src/targets/gpu/jit/roialign.cpp
src/targets/gpu/jit/roialign.cpp
+0
-1
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+0
-1
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/basic_ops.hpp
...argets/gpu/kernels/include/migraphx/kernels/basic_ops.hpp
+0
-84
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+46
-20
src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
...ts/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+8
-11
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
+45
-1
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
...targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
+9
-7
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
...gets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+15
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+11
-2
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+31
-15
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+76
-0
src/targets/gpu/schedule_model.cpp
src/targets/gpu/schedule_model.cpp
+8
-8
No files found.
src/targets/gpu/include/migraphx/gpu/schedule_model.hpp
View file @
faefeef9
...
@@ -17,9 +17,9 @@ struct schedule_model
...
@@ -17,9 +17,9 @@ struct schedule_model
{
{
std
::
size_t
streams
=
0
;
std
::
size_t
streams
=
0
;
std
::
size_t
concurrency
()
const
;
std
::
size_t
concurrency
()
const
;
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
std
::
size_t
weight
(
const
operation
&
op
)
const
;
std
::
size_t
weight
(
const
operation
&
op
)
const
;
};
};
...
...
src/targets/gpu/include/migraphx/gpu/sync_device.hpp
View file @
faefeef9
...
@@ -15,7 +15,7 @@ namespace gpu {
...
@@ -15,7 +15,7 @@ namespace gpu {
struct
sync_device
struct
sync_device
{
{
std
::
string
name
()
const
{
return
"sync_device"
;
}
std
::
string
name
()
const
{
return
"sync_device"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/include/migraphx/gpu/write_literals.hpp
View file @
faefeef9
...
@@ -14,7 +14,7 @@ struct write_literals
...
@@ -14,7 +14,7 @@ struct write_literals
context
*
ctx
=
nullptr
;
context
*
ctx
=
nullptr
;
std
::
string
name
()
const
{
return
"gpu::write_literals"
;
}
std
::
string
name
()
const
{
return
"gpu::write_literals"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/jit/gathernd.cpp
View file @
faefeef9
...
@@ -19,7 +19,7 @@ namespace gpu {
...
@@ -19,7 +19,7 @@ namespace gpu {
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/
basic_
ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
faefeef9
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
...
@@ -26,9 +27,10 @@ namespace migraphx {
...
@@ -26,9 +27,10 @@ namespace migraphx {
${preamble}
${preamble}
extern "C" {
extern "C" {
__global__ void kernel(${params})
__global__ void
${
kernel
}
(${params})
{
{
pointwise(${lambda}, ${args});
auto idx = make_index();
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
}
}
}
}
...
@@ -37,44 +39,123 @@ __global__ void kernel(${params})
...
@@ -37,44 +39,123 @@ __global__ void kernel(${params})
)__migraphx__"
;
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
static
std
::
size_t
oversubscribe
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
oversubscribe
_if
(
bool
b
)
{
{
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
broadcasted
();
}))
if
(
b
)
return
1
;
else
return
256
;
return
256
;
else
return
1
;
}
}
static
std
::
size_t
vectorize_element
s
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
find_fast_axi
s
(
const
std
::
vector
<
shape
>&
inputs
)
{
{
std
::
size_t
n
=
inputs
.
front
().
elements
();
auto
permutation
=
find_permutation
(
inputs
);
auto
it
=
std
::
max_element
(
permutation
.
begin
(),
permutation
.
end
());
return
it
-
permutation
.
begin
();
}
static
std
::
vector
<
bool
>
preload
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
inputs
.
end
(),
result
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
return
s
.
bytes
();
return
0
;
});
if
(
bytes
<
max_lds_bytes
)
return
result
;
// TODO: Try to partially preload items
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
return
result
;
}
static
std
::
string
preload_str
(
const
std
::
vector
<
bool
>&
bs
)
{
std
::
vector
<
std
::
string
>
bool_strs
;
std
::
transform
(
bs
.
begin
(),
std
::
prev
(
bs
.
end
()),
std
::
back_inserter
(
bool_strs
),
[](
bool
b
)
{
if
(
b
)
return
"true"
;
return
"false"
;
});
return
"false, "
+
join_strings
(
bool_strs
,
", "
);
}
static
std
::
vector
<
std
::
size_t
>
vector_sizes
(
const
std
::
vector
<
shape
>&
inputs
)
{
// If all inputs is half then only use half2
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
()
;
return
s
.
type
()
==
shape
::
half_type
;
}))
}))
{
return
{
2
};
if
((
n
%
4
)
==
0
)
return
{
4
,
2
};
return
n
/
4
;
else
if
((
n
%
2
)
==
0
)
return
n
/
2
;
}
}
return
n
;
static
auto
vectorize_elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
max_vec_size
),
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
1
;
});
return
*
std
::
min_element
(
max_vec_size
.
begin
(),
max_vec_size
.
end
());
}
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
vectorize_elements
(
inputs
),
oversubscribe
(
inputs
)));
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec_size
=
vectorize_elements
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
is_preloading
=
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"kernel"
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec_size
,
oversubscribe_if
(
not
is_preloading
)));
auto
src
=
interpolate_string
(
pointwise_kernel
,
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
{
"axis"
,
std
::
to_string
(
axis
)},
{
"preloads"
,
preload_str
(
preloads
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
...
@@ -100,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -100,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto
name
=
g
.
create_function
(
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()}}));
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/jit/roialign.cpp
View file @
faefeef9
...
@@ -19,7 +19,6 @@ namespace gpu {
...
@@ -19,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
roialign_kernel
=
R"__migraphx__(
static
const
char
*
const
roialign_kernel
=
R"__migraphx__(
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
src/targets/gpu/jit/scatternd.cpp
View file @
faefeef9
...
@@ -19,7 +19,6 @@ namespace gpu {
...
@@ -19,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
scatternd_kernel
=
R"__migraphx__(
static
const
char
*
const
scatternd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
faefeef9
...
@@ -146,8 +146,8 @@ struct array
...
@@ -146,8 +146,8 @@ struct array
constexpr
array
carry
(
array
result
)
const
constexpr
array
carry
(
array
result
)
const
{
{
u
in
t32_
t
overflow
=
0
;
in
dex_in
t
overflow
=
0
;
for
(
std
::
ptr
diff_t
i
=
result
.
size
()
-
1
;
i
>
0
;
i
--
)
for
(
diff_
in
t
i
=
result
.
size
()
-
1
;
i
>
0
;
i
--
)
{
{
auto
z
=
result
[
i
]
+
overflow
;
auto
z
=
result
[
i
]
+
overflow
;
// Reset overflow
// Reset overflow
...
...
src/targets/gpu/kernels/include/migraphx/kernels/basic_ops.hpp
deleted
100755 → 0
View file @
97a40ac3
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#include <migraphx/kernels/types.hpp>
namespace
migraphx
{
struct
sum
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
+
y
;
}
};
struct
product
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
*
y
;
}
};
struct
id
{
template
<
class
T
>
constexpr
auto
operator
()(
T
x
)
const
{
return
x
;
}
};
struct
mean
{
size_t
item_num
=
1
;
template
<
class
T
>
constexpr
auto
operator
()(
T
x
)
const
{
return
x
/
static_cast
<
T
>
(
item_num
);
}
};
struct
max_f
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
(
x
>
y
)
?
x
:
y
;
}
};
inline
constexpr
auto
max
=
max_f
{};
struct
min_f
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
(
x
<
y
)
?
x
:
y
;
}
};
inline
constexpr
auto
min
=
min_f
{};
struct
lowest
{
template
<
class
T
>
constexpr
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
};
struct
highest
{
template
<
class
T
>
constexpr
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
max
();
}
};
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
faefeef9
...
@@ -3,6 +3,14 @@
...
@@ -3,6 +3,14 @@
#include <migraphx/kernels/array.hpp>
#include <migraphx/kernels/array.hpp>
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
namespace
migraphx
{
namespace
migraphx
{
struct
swallow
struct
swallow
...
@@ -129,7 +137,7 @@ constexpr auto by(F f)
...
@@ -129,7 +137,7 @@ constexpr auto by(F f)
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
{
{
swallow
{(
f
(
st
d
::
forward
<
Ts
>
(
xs
)),
0
)...};
swallow
{(
f
(
st
atic_cast
<
Ts
&&
>
(
xs
)),
0
)...};
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
...
@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
template
<
class
G
,
class
F
>
constexpr
auto
join
(
G
g
,
F
f
)
{
return
f
([
=
](
auto
...
xs
)
{
return
g
(
xs
...);
});
}
template
<
class
G
,
class
F
,
class
...
Fs
>
constexpr
auto
join
(
G
g
,
F
f
,
Fs
...
fs
)
{
return
f
([
=
](
auto
...
xs
)
{
return
join
([
=
](
auto
...
ys
)
{
return
g
(
xs
...,
ys
...);
},
fs
...);
});
}
template
<
class
Compare
,
class
P1
,
class
P2
>
template
<
class
Compare
,
class
P1
,
class
P2
>
constexpr
auto
pack_compare
(
Compare
compare
,
P1
p1
,
P2
p2
)
constexpr
auto
pack_compare
(
Compare
compare
,
P1
p1
,
P2
p2
)
{
{
...
@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
...
@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
return
arg_c
<
ic
>
();
return
arg_c
<
ic
>
();
}
}
inline
constexpr
auto
rotate_last
()
template
<
class
F
>
constexpr
auto
make_transform
(
F
f
)
{
{
return
[](
auto
...
xs
)
{
return
[
=
](
auto
...
xs
)
{
return
[
=
](
auto
g
)
{
return
f
(
g
,
xs
...);
};
};
return
[
=
](
auto
&&
f
)
{
return
sequence_c
<
sizeof
...(
xs
)
>
([
&
](
auto
...
is
)
{
constexpr
auto
size
=
sizeof
...(
is
);
return
f
(
arg_c
<
(
is
+
size
-
1
)
%
size
>
()(
xs
...)...);
});
};
};
}
}
// An arg transformation takes the arguments and then a function to take the new arguments:
// transform(xs...)([](auto... ys) { ... })
// The transform_args function takes a list of transformations and continually applies them
template
<
class
F
>
template
<
class
F
>
constexpr
auto
transform_args
(
F
f
)
constexpr
auto
transform_args
(
F
f
)
{
{
return
[
=
](
auto
...
xs
)
{
return
f
;
return
[
=
](
auto
g
)
{
return
f
(
xs
...)([
&
](
auto
...
ys
)
{
return
g
(
ys
...);
});
};
};
}
}
template
<
class
F
,
class
...
Fs
>
template
<
class
F
,
class
...
Fs
>
constexpr
auto
transform_args
(
F
f
,
Fs
...
fs
)
constexpr
auto
transform_args
(
F
f
,
Fs
...
fs
)
{
{
return
[
=
](
auto
...
xs
)
{
return
transform_args
(
f
)(
xs
...)(
transform_args
(
fs
...));
};
return
make_transform
([
=
](
auto
g
,
auto
...
xs
)
{
return
f
(
xs
...)([
=
](
auto
...
ys
)
{
return
transform_args
(
fs
...)(
ys
...)(
g
);
});
});
}
}
// NOLINTNEXTLINE
// identity transform
#define MIGRAPHX_RETURNS(...) \
inline
constexpr
auto
transform_args
()
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
{
return
make_transform
([](
auto
f
,
auto
...
xs
)
{
return
f
(
xs
...);
});
}
// NOLINTNEXTLINE
// Rotate the first argument to the last argument
#define MIGRAPHX_LIFT(...) \
inline
constexpr
auto
rotate_last
()
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
{
return
make_transform
([](
auto
f
,
auto
...
xs
)
{
return
sequence_c
<
sizeof
...(
xs
)
>
([
&
](
auto
...
is
)
{
constexpr
auto
size
=
sizeof
...(
is
);
return
f
(
arg_c
<
(
is
+
size
-
1
)
%
size
>
()(
xs
...)...);
});
});
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
View file @
faefeef9
...
@@ -13,7 +13,7 @@ struct basic_iota_iterator
...
@@ -13,7 +13,7 @@ struct basic_iota_iterator
F
f
;
F
f
;
using
difference_type
=
diff_int
;
using
difference_type
=
diff_int
;
using
reference
=
decltype
(
f
(
std
::
declval
<
Iterator
>
()));
using
reference
=
decltype
(
f
(
declval
<
Iterator
>
()));
using
value_type
=
remove_reference_t
<
reference
>
;
using
value_type
=
remove_reference_t
<
reference
>
;
using
pointer
=
add_pointer_t
<
value_type
>
;
using
pointer
=
add_pointer_t
<
value_type
>
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
faefeef9
...
@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
...
@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
{
{
preload
<
typename
T
::
type
>
(
idx
,
xs
...)([
&
](
auto
...
ps
)
{
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
out
[
i
]
=
implicit_conversion
(
f
(
ps
[
i
]...));
});
[
&
](
auto
i
)
{
out
[
i
]
=
implicit_conversion
(
f
(
xs
[
i
]...));
});
});
}
}
template
<
class
F
,
class
...
T
s
>
template
<
class
...
Transform
s
>
__device__
void
pointwise
(
F
f
,
Ts
*
...
p
s
)
__device__
auto
pointwise
(
index
idx
,
Transforms
...
transform
s
)
{
{
auto
t
=
transform_args
(
make_tensors
(),
rotate_last
(),
auto_vectorize
());
return
[
=
](
auto
f
,
auto
*
...
ps
)
{
t
(
ps
...)([
&
](
auto
...
xs
)
{
auto
t
=
transform_args
(
make_tensors
(),
rotate_last
(),
transforms
...);
auto
idx
=
make_index
();
t
(
ps
...)([
&
](
auto
...
xs
)
{
pointwise_tensor
(
idx
,
f
,
xs
...);
});
pointwise_tensor
(
idx
,
f
,
xs
...);
};
});
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
View file @
faefeef9
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
...
@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{
{
if
constexpr
(
decltype
(
tensor_vec_size
(
x
)){}
==
0
)
if
constexpr
(
decltype
(
tensor_vec_size
(
x
)){}
==
0
)
{
{
auto
v
=
vectorize
(
x
);
auto
v
=
auto_
vectorize
(
x
);
auto
b
=
as_vec
(
tensor_vec_size
(
v
),
buffer
+
offset
);
auto
b
=
as_vec
(
tensor_vec_size
(
v
),
buffer
+
offset
);
idx
.
local_stride
(
v
.
get_shape
().
element_space
(),
idx
.
local_stride
(
v
.
get_shape
().
element_space
(),
[
&
](
auto
i
)
{
b
[
i
]
=
v
.
data
()[
i
];
});
[
&
](
auto
i
)
{
b
[
i
]
=
v
.
data
()[
i
];
});
...
@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
...
@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
};
};
}
}
inline
__device__
auto
auto_preload
(
index
idx
)
{
return
make_transform
([
=
](
auto
f
,
auto
out
,
auto
...
xs
)
{
preload
<
typename
decltype
(
out
)
::
type
>
(
idx
,
xs
...)([
&
](
auto
...
ys
)
{
f
(
out
,
ys
...);
});
});
}
template
<
bool
B
,
class
T
>
__device__
auto
preload_copy
(
index
idx
,
T
x
)
{
return
[
=
](
auto
f
)
{
if
constexpr
(
B
)
{
using
type
=
typename
T
::
type
;
constexpr
auto
size
=
get_shape_c
<
T
>
{}.
element_space
();
__shared__
type
buffer
[
size
];
// TODO: Always vecotrize when size > 4, and then use a second loop for remainder
constexpr
auto
n
=
find_vectorize_size
([
&
](
auto
i
)
{
return
(
size
%
i
)
==
0
;
});
auto
input
=
as_vec
<
n
>
(
remove_bool
(
x
.
data
()));
auto
b
=
as_vec
<
n
>
(
remove_bool
(
buffer
));
idx
.
local_stride
(
size
/
n
,
[
&
](
auto
i
)
{
b
[
i
]
=
input
[
i
];
});
return
f
(
x
.
with
(
buffer
));
}
else
{
return
f
(
x
);
}
};
}
template
<
bool
...
Bs
>
__device__
auto
auto_preload
(
index
idx
)
{
return
make_transform
([
=
](
auto
f
,
auto
...
xs
)
{
auto
invoke
=
[
=
](
auto
...
ys
)
{
__syncthreads
();
f
(
ys
...);
};
join
(
invoke
,
preload_copy
<
Bs
>
(
idx
,
xs
)...);
});
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
View file @
faefeef9
...
@@ -3,14 +3,15 @@
...
@@ -3,14 +3,15 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/array.hpp>
#include <migraphx/kernels/array.hpp>
namespace
migraphx
{
namespace
migraphx
{
struct
max_pool
struct
max_pool
{
{
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
{
return
lowest
()
;
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
{
return
lowest
{}
;
}
template
<
class
T
>
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
,
T
y
)
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
,
T
y
)
...
@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
...
@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
return
0
;
return
0
;
}
}
xy
[
ii
]
=
max
(
xy
[
ii
],
0.0
f
);
xy
[
ii
]
=
migraphx
::
max
(
xy
[
ii
],
0.0
f
);
low
[
ii
]
=
xy
[
ii
];
low
[
ii
]
=
xy
[
ii
];
high
[
ii
]
=
low
[
ii
]
+
1
;
high
[
ii
]
=
low
[
ii
]
+
1
;
if
(
low
[
ii
]
>=
dims
[
ii
]
-
1
)
if
(
low
[
ii
]
>=
dims
[
ii
]
-
1
)
...
@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
...
@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
for
(
index_int
ii
=
0
;
ii
<
roi_size
.
size
();
++
ii
)
for
(
index_int
ii
=
0
;
ii
<
roi_size
.
size
();
++
ii
)
{
{
roi_size
[
ii
]
=
roi_ends
[
ii
]
-
roi_starts
[
ii
];
roi_size
[
ii
]
=
roi_ends
[
ii
]
-
roi_starts
[
ii
];
roi_size
[
ii
]
=
max
(
roi_size
[
ii
],
1.0
f
);
roi_size
[
ii
]
=
migraphx
::
max
(
roi_size
[
ii
],
1.0
f
);
bin_size
[
ii
]
=
roi_size
[
ii
]
/
out_dims
[
ii
];
bin_size
[
ii
]
=
roi_size
[
ii
]
/
out_dims
[
ii
];
bin_grid_size
[
ii
]
=
bin_grid_size
[
ii
]
=
(
s
.
sampling_ratio
>
0
)
(
s
.
sampling_ratio
>
0
)
?
s
.
sampling_ratio
:
std
::
ceil
(
roi_size
[
ii
]
/
out_dims
[
ii
]);
?
s
.
sampling_ratio
:
migraphx
::
ceil
(
roi_size
[
ii
]
/
out_dims
[
ii
]);
}
}
const
auto
offset_x
=
x
+
((
batch_ind
*
channel_num
+
c
)
*
in_dims
[
0
]
*
in_dims
[
1
]);
const
auto
offset_x
=
x
+
((
batch_ind
*
channel_num
+
c
)
*
in_dims
[
0
]
*
in_dims
[
1
]);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
View file @
faefeef9
...
@@ -11,7 +11,7 @@ template <class T>
...
@@ -11,7 +11,7 @@ template <class T>
struct
tensor_view_iterator_read
struct
tensor_view_iterator_read
{
{
T
*
view
;
T
*
view
;
constexpr
auto
&
operator
()(
std
::
size_
t
n
)
const
constexpr
auto
&
operator
()(
index_in
t
n
)
const
{
{
MIGRAPHX_ASSERT
(
view
!=
nullptr
);
MIGRAPHX_ASSERT
(
view
!=
nullptr
);
return
(
*
view
)[
n
];
return
(
*
view
)[
n
];
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
faefeef9
...
@@ -35,6 +35,21 @@ struct enable_if<true, T>
...
@@ -35,6 +35,21 @@ struct enable_if<true, T>
template
<
bool
B
,
class
T
=
void
>
template
<
bool
B
,
class
T
=
void
>
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
template
<
bool
B
,
class
T
,
class
F
>
using
conditional_t
=
typename
conditional
<
B
,
T
,
F
>::
type
;
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
template <class T> \
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
faefeef9
...
@@ -60,17 +60,26 @@ constexpr auto common_vec_size()
...
@@ -60,17 +60,26 @@ constexpr auto common_vec_size()
})(
vec_size
<
Ts
>
()...);
})(
vec_size
<
Ts
>
()...);
}
}
// Bools can not be used as a vector type so convert it to uint8
template
<
class
T
>
__device__
__host__
T
*
remove_bool
(
T
*
x
)
{
return
x
;
}
inline
__device__
__host__
uint8_t
*
remove_bool
(
bool
*
x
)
{
return
reinterpret_cast
<
uint8_t
*>
(
x
);
}
template
<
index_int
N
,
class
T
>
template
<
index_int
N
,
class
T
>
__device__
__host__
auto
as_vec
(
T
*
x
)
__device__
__host__
auto
as_vec
(
T
*
x
)
{
{
if
constexpr
(
N
==
0
)
if
constexpr
(
N
<
2
)
return
x
;
return
x
;
else
else
return
reinterpret_cast
<
vec
<
T
,
N
>*>
(
x
);
return
reinterpret_cast
<
vec
<
T
,
N
>*>
(
x
);
}
}
template
<
class
T
,
index_int
N
>
template
<
class
T
,
index_int
N
>
using
safe_vec
=
vec
<
std
::
conditional_t
<
std
::
is_same
<
T
,
bool
>
{},
uint8_t
,
T
>
,
N
>
;
using
safe_vec
=
vec
<
conditional_t
<
is_same
<
T
,
bool
>
{},
uint8_t
,
T
>
,
N
>
;
template
<
class
...
Ts
>
template
<
class
...
Ts
>
constexpr
auto
vec_transform
(
Ts
...
xs
)
constexpr
auto
vec_transform
(
Ts
...
xs
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
faefeef9
...
@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
...
@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
});
});
}
}
// Bools can not be used as a vector type so convert it to uint8
template
<
class
T
>
__device__
__host__
T
*
remove_bool
(
T
*
x
)
{
return
x
;
}
inline
__device__
__host__
uint8_t
*
remove_bool
(
bool
*
x
)
{
return
reinterpret_cast
<
uint8_t
*>
(
x
);
}
template
<
index_int
N
,
class
T
,
class
Axis
>
template
<
index_int
N
,
class
T
,
class
Axis
>
__device__
__host__
auto
as_vec
(
T
x
,
Axis
axis
)
__device__
__host__
auto
as_vec
(
T
x
,
Axis
axis
)
{
{
if
constexpr
(
N
==
0
)
if
constexpr
(
N
<
2
)
return
x
;
return
x
;
else
else
return
make_tensor_view
(
as_vec
<
N
>
(
remove_bool
(
x
.
data
())),
return
make_tensor_view
(
as_vec
<
N
>
(
remove_bool
(
x
.
data
())),
...
@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
...
@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
template
<
index_int
N
,
class
T
,
class
Axis
>
template
<
index_int
N
,
class
T
,
class
Axis
>
constexpr
auto
tensor_step
(
T
x
,
Axis
axis
)
constexpr
auto
tensor_step
(
T
x
,
Axis
axis
)
{
{
if
constexpr
(
N
==
0
)
if
constexpr
(
N
<
2
)
{
{
return
x
;
return
x
;
}
}
...
@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
...
@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
else
if
constexpr
(
decltype
(
pred
(
_c
<
2
>
)){})
else
if
constexpr
(
decltype
(
pred
(
_c
<
2
>
)){})
return
_c
<
2
>
;
return
_c
<
2
>
;
else
else
return
_c
<
0
>
;
return
_c
<
1
>
;
}
}
template
<
class
T
>
template
<
class
T
>
__host__
__device__
auto
vectorize
(
T
x
)
__host__
__device__
auto
auto_
vectorize
(
T
x
)
{
{
if
constexpr
(
tensor_vec_size
<
T
>
()
==
0
)
if
constexpr
(
tensor_vec_size
<
T
>
()
==
0
)
{
{
...
@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
...
@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{
{
MIGRAPHX_ASSERT
(
s
.
strides
[
axis
]
==
0
or
s
.
strides
[
axis
]
==
1
);
MIGRAPHX_ASSERT
(
s
.
strides
[
axis
]
==
0
or
s
.
strides
[
axis
]
==
1
);
MIGRAPHX_ASSERT
(
s
.
lens
[
axis
]
>
0
);
MIGRAPHX_ASSERT
(
s
.
lens
[
axis
]
>
0
);
MIGRAPHX_ASSERT
(
n
==
0
or
s
.
lens
[
axis
]
%
n
==
0
);
MIGRAPHX_ASSERT
(
n
==
1
or
s
.
lens
[
axis
]
%
n
==
0
);
if
constexpr
(
s
.
strides
[
axis
]
==
0
)
if
constexpr
(
s
.
strides
[
axis
]
==
0
)
return
tensor_step
<
n
>
(
x
,
axis
);
return
tensor_step
<
n
>
(
x
,
axis
);
else
else
...
@@ -215,7 +206,32 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
...
@@ -215,7 +206,32 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
inline
__device__
__host__
auto
auto_vectorize
()
inline
__device__
__host__
auto
auto_vectorize
()
{
{
return
[](
auto
...
xs
)
{
return
[
=
](
auto
f
)
{
auto_vectorize_impl
(
f
,
xs
...);
};
};
return
make_transform
([](
auto
f
,
auto
...
xs
)
{
auto_vectorize_impl
(
f
,
xs
...);
});
}
template
<
index_int
N
,
index_int
Axis
,
class
T
>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
shape
.
strides
[
Axis
]
==
0
)
return
tensor_step
<
N
>
(
x
,
_c
<
Axis
>
);
else
return
as_vec
<
N
>
(
x
,
_c
<
Axis
>
);
}
template
<
index_int
N
,
index_int
Axis
>
__device__
__host__
auto
vectorize
()
{
return
make_transform
([](
auto
f
,
auto
...
xs
)
{
if
constexpr
(
N
<
2
)
{
f
(
xs
...);
}
else
{
f
(
vectorize_tensor
<
N
,
Axis
>
(
xs
)...);
}
});
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/prefuse_ops.cpp
0 → 100644
View file @
faefeef9
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
{
struct
find_layernorm
{
auto
matcher
()
const
{
return
match
::
layernorm
();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
if
(
not
x_ins
->
get_shape
().
standard
())
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
x_ins
);
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::layernorm"
),
x_ins
,
a
);
}
};
struct
find_triaddlayernorm
{
auto
matcher
()
const
{
auto
add1
=
match
::
name
(
"add"
)(
match
::
none_of
(
match
::
is_constant
()),
match
::
args
(
match
::
any
().
bind
(
"z1"
),
match
::
any
().
bind
(
"z2"
)));
auto
add2
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
add1
,
match
::
any
().
bind
(
"z3"
)));
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
add2
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"z1"
];
auto
y_ins
=
r
.
instructions
[
"z2"
];
auto
z_ins
=
r
.
instructions
[
"z3"
];
for
(
auto
*
pins
:
{
&
x_ins
,
&
y_ins
,
&
z_ins
})
{
if
(
not
(
*
pins
)
->
get_shape
().
standard
())
*
pins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
*
pins
);
}
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::triadd_layernorm"
),
x_ins
,
y_ins
,
z_ins
,
a
);
}
};
}
// namespace
void
prefuse_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_triaddlayernorm
{},
find_layernorm
{});
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/schedule_model.cpp
View file @
faefeef9
...
@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
...
@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
MIGRAPHX_REGISTER_OP
(
set_stream
)
MIGRAPHX_REGISTER_OP
(
set_stream
)
std
::
size_t
schedule_model
::
concurrency
()
const
{
return
streams
;
}
std
::
size_t
schedule_model
::
concurrency
()
const
{
return
streams
;
}
void
schedule_model
::
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
void
schedule_model
::
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
{
{
auto
last_stream
=
std
::
find_if
(
std
::
make_reverse_iterator
(
ins
),
auto
last_stream
=
std
::
find_if
(
std
::
make_reverse_iterator
(
ins
),
std
::
make_reverse_iterator
(
p
.
begin
()),
std
::
make_reverse_iterator
(
m
.
begin
()),
[
&
](
auto
&&
i
)
{
return
i
.
name
()
==
"gpu::set_stream"
;
});
[
&
](
auto
&&
i
)
{
return
i
.
name
()
==
"gpu::set_stream"
;
});
if
(
last_stream
!=
std
::
make_reverse_iterator
(
p
.
begin
()))
if
(
last_stream
!=
std
::
make_reverse_iterator
(
m
.
begin
()))
{
{
auto
&&
op
=
any_cast
<
set_stream
>
(
last_stream
->
get_operator
());
auto
&&
op
=
any_cast
<
set_stream
>
(
last_stream
->
get_operator
());
// If the same stream was set earlier then skip
// If the same stream was set earlier then skip
if
(
op
.
stream
==
n
)
if
(
op
.
stream
==
n
)
return
;
return
;
}
}
p
.
insert_instruction
(
ins
,
set_stream
{
n
});
m
.
insert_instruction
(
ins
,
set_stream
{
n
});
}
}
void
schedule_model
::
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
void
schedule_model
::
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
{
p
.
insert_instruction
(
ins
,
wait_event
{
wait_id
});
m
.
insert_instruction
(
ins
,
wait_event
{
wait_id
});
}
}
void
schedule_model
::
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
void
schedule_model
::
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
{
p
.
insert_instruction
(
std
::
next
(
ins
),
record_event
{
wait_id
});
m
.
insert_instruction
(
std
::
next
(
ins
),
record_event
{
wait_id
});
}
}
static
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
create_weight_map
()
static
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
create_weight_map
()
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment