Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
14329696
Commit
14329696
authored
Aug 22, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu
parents
a48c41a9
79e15ca9
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
420 additions
and
108 deletions
+420
-108
src/targets/gpu/jit/layernorm.cpp
src/targets/gpu/jit/layernorm.cpp
+133
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+8
-40
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+3
-2
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+83
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+19
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+49
-37
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+0
-1
test/fpga/get_target_assignments.cpp
test/fpga/get_target_assignments.cpp
+11
-8
test/include/test.hpp
test/include/test.hpp
+12
-9
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+87
-0
test/verify/test_layernorm.cpp
test/verify/test_layernorm.cpp
+1
-1
tools/include/target.hpp
tools/include/target.hpp
+10
-8
tools/te.py
tools/te.py
+3
-1
No files found.
src/targets/gpu/jit/layernorm.cpp
0 → 100644
View file @
14329696
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
layernorm_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
layernorm_compiler
:
compiler
<
layernorm_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"layernorm"
,
"gpu::prelayernorm"
,
"gpu::preadd_layernorm"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
// TODO: Use reduce_dims
auto
axis
=
inputs
.
front
().
lens
().
size
()
-
1
;
auto
faxis
=
find_fast_axis
({
inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
axis
==
faxis
)
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
(
inputs
.
back
().
elements
()
/
inputs
[
0
].
lens
()[
axis
]);
auto
block_size
=
compute_block_size
(
relements
,
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"layernorm_kernel"
);
auto
src
=
interpolate_string
(
layernorm_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"layernorm"
,
v
.
get
(
"layernorm"
,
std
::
string
{
"layernorm"
})},
{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
v
[
"layernorm"
]
=
"layernorm"
;
v
[
"kernel"
]
=
"layernorm_kernel"
;
if
(
op
.
name
()
==
"gpu::preadd_layernorm"
)
{
v
[
"layernorm"
]
=
"add_layernorm"
;
v
[
"kernel"
]
=
"add_layernorm_kernel"
;
}
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_layernorm"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_layernorm)"
;
v
[
"kernel"
]
=
v
[
"layernorm"
].
to
<
std
::
string
>
()
+
"_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/pointwise.cpp
View file @
14329696
...
...
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
...
...
@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
g
.
add_point_op
(
"mod"
,
"migraphx::mod(${0}, ${1})"
);
g
.
add_point_op
(
"fmod"
,
"migraphx::fmod(${0}, ${1})"
);
// Add explict conversions
g
.
fresult
([](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
pf
=
generate_pointwise
(
*
pm
,
"inner_pointwise"
);
std
::
string
lambda
=
"MIGRAPHX_LIFT(inner_pointwise)"
;
auto
kernel_name
=
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
pf
},
{
"kernel"
,
kernel_name
}}));
}
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
14329696
...
...
@@ -31,8 +31,9 @@
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
0 → 100644
View file @
14329696
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace
migraphx
{
template
<
index_int
Axis
,
class
F
,
class
BinOp
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
auto
mean
=
[
&
](
auto
f
)
{
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x1
,
auto
x2
)
{
return
f
(
x1
,
x2
)
/
value_type
{
relements
};
})(
input1
,
input2
);
};
// mean(x)
auto
mean_x
=
mean
(
op
);
// mean(m ^ 2)
auto
mean_m2
=
mean
([
&
](
auto
x1
,
auto
x2
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
return
m
*
m
;
});
r
.
inner
([
&
](
auto
&
y
,
auto
x1
,
auto
x2
,
auto
...
xs
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y
=
compute
(
m
*
rsqrt
(
mean_m2
+
value_type
{
1e-12
}),
xs
...);
})(
output
,
input1
,
input2
,
inputs
...);
});
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input
,
class
...
Inputs
>
__device__
void
layernorm
(
F
compute
,
Output
output
,
Input
input
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x
,
auto
)
{
return
x
;
},
output
,
input
,
input
,
inputs
...);
}
template
<
index_int
Axis
,
class
F
,
class
Output
,
class
Input1
,
class
Input2
,
class
...
Inputs
>
__device__
void
add_layernorm
(
F
compute
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
generic_binary_layernorm
<
Axis
>
(
compute
,
[](
auto
x1
,
auto
x2
)
{
return
x1
+
x2
;
},
output
,
input1
,
input2
,
inputs
...);
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
14329696
...
...
@@ -224,6 +224,18 @@ struct block
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
using
reduce_type
=
decltype
(
slicer
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
}
};
template
<
class
Slicer
>
...
...
@@ -281,6 +293,13 @@ struct lane
}
});
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
using
reduce_type
=
decltype
(
slicer
(
Input
{}));
return
get_shape_c
<
reduce_type
>
{}.
elements
();
}
};
template
<
class
Slicer
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
14329696
...
...
@@ -175,7 +175,7 @@ template <class T, class Op>
constexpr
auto
vec_reduce
(
T
x
,
Op
op
)
{
if
constexpr
(
vec_size
<
T
>
()
<
2
)
return
x
;
return
vec_type
<
T
>
{
x
}
;
else
{
vec_type
<
T
>
result
=
x
[
0
];
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
14329696
...
...
@@ -24,12 +24,53 @@
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
{
template
<
class
Derived
,
std
::
size_t
N
>
struct
layernorm_base
{
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
std
::
size_t
nargs
=
1
;
if
(
not
mods
.
empty
())
{
auto
*
pm
=
mods
.
front
();
nargs
=
pm
->
get_parameter_names
().
size
();
}
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
+
N
);
auto
s
=
inputs
.
at
(
0
);
if
(
s
.
scalar
())
{
return
s
;
}
else
if
(
s
.
broadcasted
())
{
return
{
s
.
type
(),
s
.
lens
()};
}
else
{
return
s
.
with_lens
(
s
.
lens
());
}
}
};
struct
layernorm
:
layernorm_base
<
layernorm
,
0
>
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
layernorm
);
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
1
>
{
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
};
MIGRAPHX_REGISTER_OP
(
add_layernorm
);
struct
find_layernorm
{
auto
matcher
()
const
{
return
match
::
layernorm
();
}
...
...
@@ -39,59 +80,30 @@ struct find_layernorm
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
if
(
not
x_ins
->
get_shape
().
standard
())
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
x_ins
);
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
a
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::layernorm"
),
x_ins
,
a
);
m
.
replace_instruction
(
ins
,
layernorm
{},
x_ins
);
}
};
struct
find_
tri
addlayernorm
struct
find_add
_
layernorm
{
auto
matcher
()
const
{
auto
add1
=
match
::
name
(
"add"
)(
match
::
none_of
(
match
::
is_constant
()),
match
::
args
(
match
::
any
().
bind
(
"z1"
),
match
::
any
().
bind
(
"z2"
)));
auto
add2
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
add1
,
match
::
any
().
bind
(
"z3"
)));
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
add2
));
return
match
::
layernorm
()(
match
::
var
(
"x"
)(
match
::
name
(
"add"
).
bind
(
"add"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"z1"
];
auto
y_ins
=
r
.
instructions
[
"z2"
];
auto
z_ins
=
r
.
instructions
[
"z3"
];
for
(
auto
*
pins
:
{
&
x_ins
,
&
y_ins
,
&
z_ins
})
{
if
(
not
(
*
pins
)
->
get_shape
().
standard
())
*
pins
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
*
pins
);
}
auto
relements
=
x_ins
->
get_shape
().
lens
().
back
();
if
(
relements
>
1024
or
(
relements
%
4
!=
0
and
relements
>
256
))
return
;
auto
ins
=
r
.
result
;
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
a
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::allocate"
,
{{
"shape"
,
to_value
(
x_ins
->
get_shape
())}}));
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::triadd_layernorm"
),
x_ins
,
y_ins
,
z_ins
,
a
);
m
.
replace_instruction
(
ins
,
add_layernorm
{},
add_ins
->
inputs
());
}
};
}
// namespace
void
prefuse_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_
tri
addlayernorm
{},
find_layernorm
{});
match
::
find_matches
(
m
,
find_add
_
layernorm
{},
find_layernorm
{});
}
}
// namespace gpu
...
...
src/targets/ref/lowering.cpp
View file @
14329696
...
...
@@ -244,7 +244,6 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
auto
weights_lens
=
args
[
1
].
get_shape
().
lens
();
std
::
vector
<
std
::
size_t
>
k_lens
{
weights_lens
.
begin
()
+
2
,
weights_lens
.
end
()};
padding
=
calc_dyn_auto_pad
(
img_lens
,
k_lens
,
op
.
stride
,
op
.
dilation
);
std
::
cout
<<
"[ "
;
output_shape
=
compute_padded_shape
({
args
.
at
(
0
).
get_shape
(),
args
.
at
(
1
).
get_shape
()},
padding
);
}
...
...
test/get_target_assignments.cpp
→
test/
fpga/
get_target_assignments.cpp
View file @
14329696
...
...
@@ -26,8 +26,9 @@
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/
ref
/target.hpp>
#include <migraphx/
fpga
/target.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/iterator_for.hpp>
migraphx
::
program
create_program
()
{
...
...
@@ -37,8 +38,8 @@ migraphx::program create_program()
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s
);
auto
diff
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
div
"
),
x
,
y
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"
div
"
),
diff
,
z
);
auto
diff
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
add
"
),
x
,
y
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"
add
"
),
diff
,
z
);
return
p
;
}
...
...
@@ -47,14 +48,16 @@ TEST_CASE(is_supported)
auto
p
=
create_program
();
auto
targets
=
migraphx
::
get_targets
();
EXPECT
(
!
targets
.
empty
());
auto
first_target
=
targets
[
0
];
auto
t
=
migraphx
::
make_target
(
first_target
);
auto
t
=
migraphx
::
make_target
(
"fpga"
);
const
auto
assignments
=
p
.
get_target_assignments
({
t
});
for
(
const
auto
&
[
ins
,
target
]
:
assignments
)
const
auto
*
mod
=
p
.
get_main_module
();
EXPECT
(
mod
->
size
()
==
assignments
.
size
());
for
(
const
auto
ins
:
iterator_for
(
*
mod
))
{
(
void
)
ins
;
EXPECT
(
target
==
first_target
);
const
auto
&
target
=
assignments
.
at
(
ins
)
;
EXPECT
(
target
==
"fpga"
);
}
}
...
...
test/include/test.hpp
View file @
14329696
...
...
@@ -108,15 +108,7 @@ struct function
};
template
<
class
Stream
,
class
Iterator
>
inline
Stream
&
stream_range
(
Stream
&
s
,
Iterator
start
,
Iterator
last
)
{
if
(
start
!=
last
)
{
s
<<
*
start
;
std
::
for_each
(
std
::
next
(
start
),
last
,
[
&
](
auto
&&
x
)
{
s
<<
", "
<<
x
;
});
}
return
s
;
}
Stream
&
stream_range
(
Stream
&
s
,
Iterator
start
,
Iterator
last
);
template
<
class
Stream
>
inline
Stream
&
operator
<<
(
Stream
&
s
,
std
::
nullptr_t
)
...
...
@@ -136,6 +128,17 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return
s
;
}
template
<
class
Stream
,
class
Iterator
>
inline
Stream
&
stream_range
(
Stream
&
s
,
Iterator
start
,
Iterator
last
)
{
if
(
start
!=
last
)
{
s
<<
*
start
;
std
::
for_each
(
std
::
next
(
start
),
last
,
[
&
](
auto
&&
x
)
{
s
<<
", "
<<
x
;
});
}
return
s
;
}
template
<
class
T
>
const
T
&
get_value
(
const
T
&
x
)
{
...
...
test/simplify_reshapes_test.cpp
View file @
14329696
...
...
@@ -39,6 +39,15 @@ void run_pass(migraphx::module& m)
migraphx
::
run_passes
(
m
,
{
migraphx
::
simplify_reshapes
{},
migraphx
::
dead_code_elimination
{}});
}
inline
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
to_lens
(
const
std
::
vector
<
migraphx
::
shape
>&
shapes
)
{
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
result
;
std
::
transform
(
shapes
.
begin
(),
shapes
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
();
});
return
result
;
}
TEST_CASE
(
double_contig
)
{
migraphx
::
program
p
;
...
...
@@ -1275,4 +1284,82 @@ TEST_CASE(transpose_slice_single_transpose)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
transpose_slice_non_packed_axis
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
384
,
36
,
64
}});
auto
transpose
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
x
);
auto
slice
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
transpose
);
auto
sqrt
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
slice
);
m1
.
add_return
({
sqrt
});
}
auto
output_shapes
=
m1
.
get_output_shapes
();
run_pass
(
m1
);
EXPECT
(
m1
.
get_output_shapes
()
==
output_shapes
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
384
,
36
,
64
}});
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
12
}}}),
x
);
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
3
,
0
,
2
,
1
,
4
}}}),
unsqueeze
);
auto
slice
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
transpose
);
auto
squeeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
slice
);
auto
sqrt
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
squeeze
);
m2
.
add_return
({
sqrt
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
transpose_slice_non_packed_multi_axis
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
384
,
36
,
64
}});
auto
transpose
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
x
);
auto
slice1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
transpose
);
auto
slice2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
transpose
);
auto
transpose2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
slice2
);
auto
slice3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
transpose
);
m1
.
add_return
({
slice1
,
transpose2
,
slice3
});
}
auto
output_shapes
=
m1
.
get_output_shapes
();
run_pass
(
m1
);
EXPECT
(
to_lens
(
m1
.
get_output_shapes
())
==
to_lens
(
output_shapes
));
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
384
,
36
,
64
}});
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
12
}}}),
x
);
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
3
,
0
,
2
,
1
,
4
}}}),
unsqueeze
);
auto
slice1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
transpose
);
auto
squeeze1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
slice1
);
auto
slice2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
transpose
);
auto
squeeze2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
slice2
);
auto
transpose2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
squeeze2
);
auto
slice3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
transpose
);
auto
squeeze3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
slice3
);
m2
.
add_return
({
squeeze1
,
transpose2
,
squeeze3
});
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/verify/test_layernorm.cpp
View file @
14329696
...
...
@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
size_t
>
dims
=
{
1
,
1
,
5
};
std
::
vector
<
size_t
>
dims
=
{
1
,
2
,
5
};
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
dims
});
add_layernorm
(
*
mm
,
x
,
dims
);
return
p
;
...
...
tools/include/target.hpp
View file @
14329696
...
...
@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -64,12 +66,12 @@ struct target
*/
context
get_context
()
const
;
/**
* @brief
Check how well an
instruction
is
supported on a target
with the given metric
* @param
ins Instruction
to check
if it's
supported
* @param metric Used to define how the
return value
should be
interpret
ed
* @return
T
he
value based on the chosen metric. Negative numbers mean unsupported
* @brief
Get the ranges of
instruction
s that are
supported on a target
* @param
module Module
to check
for
supported
instructions
* @param metric Used to define how the
quality of the support
should be
measur
ed
* @return
t
he
supported segments of the graph
*/
float
is_supported
(
T
&
,
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
target_
is_supported
(
T
&
,
co
nst
_module
_ref
mod
,
support_metric
m
etric
)
const
;
/**
* @brief copy an argument to the current target.
*
...
...
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
}
template
<
class
T
>
float
target_
is
_supported
(
T
&
,
i
nst
ruction
_ref
,
support_metric
)
supported_segments
target_
find
_supported
(
T
&
,
co
nst
_module
_ref
,
support_metric
)
{
return
0
;
return
{}
;
}
<%
...
...
@@ -125,7 +127,7 @@ interface('target',
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
get_passes
'
,
ctx
=
'
context
&
'
,
options
=
'
const
compile_options
&
'
,
returns
=
'
std
::
vector
<
pass
>
'
,
const
=
True
),
virtual
(
'
get_context
'
,
returns
=
'
context
'
,
const
=
True
),
virtual
(
'
is
_supported
'
,
returns
=
'
float
'
,
ins
=
'
instruction
_ref
'
,
m
=
'
support_metric
'
,
const
=
True
,
default
=
'
target_
is
_supported
'
),
virtual
(
'
find
_supported
'
,
returns
=
'
supported_segments
'
,
mod
=
'
const_module
_ref
'
,
m
=
'
support_metric
'
,
const
=
True
,
default
=
'
target_
find
_supported
'
),
virtual
(
'
copy_to
'
,
returns
=
'
argument
'
,
input
=
'
const
argument
&
'
,
...
...
tools/te.py
View file @
14329696
...
...
@@ -23,7 +23,9 @@
#####################################################################################
import
string
,
sys
,
re
trivial
=
[
'std::size_t'
,
'instruction_ref'
,
'support_metric'
]
trivial
=
[
'std::size_t'
,
'instruction_ref'
,
'support_metric'
,
'const_module_ref'
]
headers
=
'''
#include <algorithm>
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment