Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e297b13
Commit
7e297b13
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
86ea5e91
aa7ff911
Changes
765
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1588 additions
and
88 deletions
+1588
-88
src/targets/gpu/jit/gathernd.cpp
src/targets/gpu/jit/gathernd.cpp
+75
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+126
-0
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+179
-0
src/targets/gpu/jit/roialign.cpp
src/targets/gpu/jit/roialign.cpp
+87
-0
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+86
-0
src/targets/gpu/kernel.cpp
src/targets/gpu/kernel.cpp
+2
-0
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+49
-0
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+52
-40
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
+142
-8
src/targets/gpu/kernels/include/migraphx/kernels/dfor.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dfor.hpp
+25
-0
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
+55
-0
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+97
-17
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+81
-0
src/targets/gpu/kernels/include/migraphx/kernels/generic_constant.hpp
...gpu/kernels/include/migraphx/kernels/generic_constant.hpp
+33
-0
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+11
-0
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+13
-12
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp
...pu/kernels/include/migraphx/kernels/integral_constant.hpp
+20
-11
src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
...ts/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
+145
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+227
-0
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+83
-0
No files found.
src/targets/gpu/jit/gathernd.cpp
0 → 100644
View file @
7e297b13
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
gathernd_compiler
:
compiler
<
gathernd_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gathernd"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
auto
out_s
=
inputs
.
back
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"gathernd_kernel"
;
options
.
virtual_inputs
=
inputs
;
// batch_dims
assert
(
v
.
contains
(
"batch_dims"
));
auto
batch_dims
=
v
.
at
(
"batch_dims"
).
to
<
int64_t
>
();
options
.
params
+=
" -DBATCH_DIMS="
+
std
::
to_string
(
batch_dims
);
return
compile_hip_code_object
(
gathernd_kernel
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/pointwise.cpp
0 → 100644
View file @
7e297b13
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
pointwise(idx, ${transformers})(${lambda}, ${args});
}
}
} // namespace migraphx
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
static
std
::
size_t
oversubscribe_if
(
bool
b
)
{
if
(
b
)
return
256
;
else
return
1
;
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
options
.
virtual_inputs
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"kernel"
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec
.
size
,
oversubscribe_if
(
not
preloads
.
is_preloading
())));
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
)
const
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/reduce.cpp
0 → 100644
View file @
7e297b13
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void reduce_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
});
}
}
} // namespace migraphx
)__migraphx__"
;
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
}
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
instruction_ref
>&
inputs
)
{
return
get_reduce_elements
(
to_shapes
(
inputs
));
}
static
std
::
vector
<
std
::
size_t
>
get_reduce_lens
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
const
std
::
vector
<
std
::
size_t
>&
output_lens
)
{
std
::
vector
<
std
::
size_t
>
reduce_lens
;
std
::
transform
(
output_lens
.
begin
(),
output_lens
.
end
(),
input_lens
.
begin
(),
std
::
back_inserter
(
reduce_lens
),
[](
auto
x
,
auto
y
)
->
std
::
size_t
{
if
(
x
==
y
)
return
1
;
else
return
y
;
});
return
reduce_lens
;
}
static
std
::
string
get_reduce_algo
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
rlens
=
get_reduce_lens
(
inputs
.
front
().
lens
(),
inputs
.
back
().
lens
());
const
auto
init
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
// The minimum stride
auto
min_stride
=
std
::
inner_product
(
rlens
.
begin
(),
rlens
.
end
(),
inputs
.
front
().
strides
().
begin
(),
init
,
[](
auto
x
,
auto
y
)
{
return
std
::
min
(
x
,
y
);
},
[](
auto
len
,
auto
stride
)
{
return
len
==
1
?
init
:
stride
;
});
if
(
min_stride
>
2
)
return
"lane"
;
return
"block"
;
}
struct
reduce_compiler
:
compiler
<
reduce_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"reduce"
,
"reduce_sum"
,
"reduce_mean"
,
"reduce_max"
,
"reduce_min"
,
"reduce_prod"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
{
vec
=
vectorize
::
elements
(
faxis
,
options
.
virtual_inputs
);
}
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
if
(
algo
==
"block"
)
{
auto
block_size
=
compute_block_size
(
relements
,
256
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
}
else
if
(
algo
==
"lane"
)
{
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
,
256
));
}
else
{
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
}
options
.
kernel_name
=
"reduce_kernel"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
value
v
=
value
::
object
{};
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
if
(
op
.
name
()
==
"reduce_sum"
)
{
v
[
"reduction"
]
=
"op::sum{}"
;
}
else
if
(
op
.
name
()
==
"reduce_mean"
)
{
v
[
"reduction"
]
=
"op::sum{}"
;
v
[
"write"
]
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
}
else
if
(
op
.
name
()
==
"reduce_max"
)
{
v
[
"reduction"
]
=
"op::max{}"
;
v
[
"init"
]
=
"lowest{}"
;
}
else
if
(
op
.
name
()
==
"reduce_min"
)
{
v
[
"reduction"
]
=
"op::min{}"
;
v
[
"init"
]
=
"highest{}"
;
}
else
if
(
op
.
name
()
==
"reduce_prod"
)
{
v
[
"reduction"
]
=
"op::product{}"
;
v
[
"init"
]
=
"1"
;
}
else
{
MIGRAPHX_THROW
(
"Unsupported reduce"
);
}
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/roialign.cpp
0 → 100644
View file @
7e297b13
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
roialign_kernel
=
R"__migraphx__(
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
_c<bool{IS_AVG_POOLING}>,
_c<int64_t{SAMPLING_RATIO}>,
MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE}));
roialign(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
roialign_compiler
:
compiler
<
roialign_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"roialign"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
128
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"roialign_kernel"
;
// sampling_ratio
options
.
params
+=
" -DSAMPLING_RATIO="
+
v
.
at
(
"sampling_ratio"
).
to
<
std
::
string
>
();
// pooling_mode
auto
mode
=
v
.
at
(
"mode"
).
to
<
migraphx
::
op
::
pooling_mode
>
();
std
::
string
is_avg_pooling
=
(
mode
==
migraphx
::
op
::
pooling_mode
::
average
)
?
"true"
:
"false"
;
options
.
params
+=
" -DIS_AVG_POOLING="
+
is_avg_pooling
;
// coord_trans_mode
auto
ctm
=
v
.
at
(
"coordinate_transformation_mode"
).
to
<
std
::
string
>
();
float
rois_offset
=
(
ctm
==
"output_half_pixel"
)
?
-
0.5
f
:
0.0
f
;
options
.
params
+=
" -DROIS_OFFSET="
+
std
::
to_string
(
rois_offset
);
// spatial_scale
options
.
params
+=
" -DSPATIAL_SCALE="
+
v
.
at
(
"spatial_scale"
).
to
<
std
::
string
>
();
return
compile_hip_code_object
(
roialign_kernel
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/scatternd.cpp
0 → 100644
View file @
7e297b13
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
scatternd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{});
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
scatternd_compiler
:
compiler
<
scatternd_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
"scatternd_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
reduction
=
"assign_"
+
v
.
get
(
"reduction"
,
std
::
string
{
"none"
});
auto
src
=
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
reduction
}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
assert
(
starts_with
(
op
.
name
(),
"scatternd_"
));
auto
reduction
=
op
.
name
().
substr
(
10
);
return
insert
(
compile_op
(
ctx
,
to_shapes
({
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()}),
{{
"reduction"
,
reduction
}}));
}
compiler_replace
insert
(
const
operation
&
op
)
const
{
return
[
=
](
module
&
m
,
instruction_ref
ins
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
};
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernel.cpp
View file @
7e297b13
...
@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
...
@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
void
*
kernargs
,
void
*
kernargs
,
std
::
size_t
size
)
std
::
size_t
size
)
{
{
assert
(
global
>
0
);
assert
(
local
>
0
);
void
*
config
[]
=
{
void
*
config
[]
=
{
// HIP_LAUNCH_PARAM_* are macros that do horrible things
// HIP_LAUNCH_PARAM_* are macros that do horrible things
#ifdef MIGRAPHX_USE_CLANG_TIDY
#ifdef MIGRAPHX_USE_CLANG_TIDY
...
...
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
7e297b13
...
@@ -21,6 +21,26 @@ struct greater
...
@@ -21,6 +21,26 @@ struct greater
}
}
};
};
template
<
class
InputIt
,
class
T
,
class
BinaryOperation
>
constexpr
T
accumulate
(
InputIt
first
,
InputIt
last
,
T
init
,
BinaryOperation
op
)
{
for
(;
first
!=
last
;
++
first
)
{
init
=
op
(
std
::
move
(
init
),
*
first
);
}
return
init
;
}
template
<
class
InputIt
,
class
OutputIt
>
constexpr
OutputIt
copy
(
InputIt
first
,
InputIt
last
,
OutputIt
d_first
)
{
while
(
first
!=
last
)
{
*
d_first
++
=
*
first
++
;
}
return
d_first
;
}
template
<
class
Iterator
,
class
Compare
>
template
<
class
Iterator
,
class
Compare
>
constexpr
Iterator
is_sorted_until
(
Iterator
first
,
Iterator
last
,
Compare
comp
)
constexpr
Iterator
is_sorted_until
(
Iterator
first
,
Iterator
last
,
Compare
comp
)
{
{
...
@@ -96,6 +116,35 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
...
@@ -96,6 +116,35 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
}
}
}
}
template
<
class
InputIt1
,
class
InputIt2
,
class
T
,
class
BinaryOperation1
,
class
BinaryOperation2
>
constexpr
T
inner_product
(
InputIt1
first1
,
InputIt1
last1
,
InputIt2
first2
,
T
init
,
BinaryOperation1
op1
,
BinaryOperation2
op2
)
{
while
(
first1
!=
last1
)
{
init
=
op1
(
init
,
op2
(
*
first1
,
*
first2
));
++
first1
;
++
first2
;
}
return
init
;
}
template
<
class
InputIt1
,
class
InputIt2
,
class
T
>
constexpr
T
inner_product
(
InputIt1
first1
,
InputIt1
last1
,
InputIt2
first2
,
T
init
)
{
return
inner_product
(
first1
,
last1
,
first2
,
init
,
[](
auto
x
,
auto
y
)
{
return
x
+
y
;
},
[](
auto
x
,
auto
y
)
{
return
x
*
y
;
});
}
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
7e297b13
...
@@ -2,40 +2,51 @@
...
@@ -2,40 +2,51 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
namespace
migraphx
{
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
constexpr array& operator op(const array& x) \
template <class U> \
{ \
constexpr array& operator op(const array<U, N>& x) \
for(index_int i = 0; i < N; i++) \
{ \
d[i] op x[i]; \
for(index_int i = 0; i < N; i++) \
return *this; \
d[i] op x[i]; \
} \
return *this; \
constexpr array& operator op(const T& x) \
} \
{ \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
for(index_int i = 0; i < N; i++) \
constexpr array& operator op(const U& x) \
d[i] op x; \
{ \
return *this; \
for(index_int i = 0; i < N; i++) \
} \
d[i] op x; \
friend constexpr array operator binary_op(const array& x, const array& y) \
return *this; \
{ \
} \
auto z = x; \
template <class U> \
return z op y; \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
} \
{ \
friend constexpr array operator binary_op(const array& x, const T& y) \
array<decltype(T {} binary_op U{}), N> z{}; \
{ \
for(index_int i = 0; i < N; i++) \
auto z = x; \
z[i] = x[i] binary_op y[i]; \
return z op y; \
return z; \
} \
} \
friend constexpr array operator binary_op(const T& x, const array& y) \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
{ \
friend constexpr auto operator binary_op(const array& x, const U& y) \
for(index_int i = 0; i < N; i++) \
{ \
y[i] = x op y[i]; \
array<decltype(T {} binary_op U{}), N> z{}; \
return y; \
for(index_int i = 0; i < N; i++) \
z[i] = x[i] binary_op y; \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const U& x, const array& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x binary_op y[i]; \
return z; \
}
}
template
<
class
T
,
index_int
N
>
template
<
class
T
,
index_int
N
>
...
@@ -63,6 +74,7 @@ struct array
...
@@ -63,6 +74,7 @@ struct array
constexpr
const
T
*
data
()
const
{
return
d
;
}
constexpr
const
T
*
data
()
const
{
return
d
;
}
constexpr
index_constant
<
N
>
size
()
const
{
return
{};
}
constexpr
index_constant
<
N
>
size
()
const
{
return
{};
}
constexpr
auto
empty
()
const
{
return
size
()
==
_c
<
0
>
;
}
constexpr
T
*
begin
()
{
return
d
;
}
constexpr
T
*
begin
()
{
return
d
;
}
constexpr
const
T
*
begin
()
const
{
return
d
;
}
constexpr
const
T
*
begin
()
const
{
return
d
;
}
...
@@ -134,8 +146,8 @@ struct array
...
@@ -134,8 +146,8 @@ struct array
constexpr
array
carry
(
array
result
)
const
constexpr
array
carry
(
array
result
)
const
{
{
u
in
t32_
t
overflow
=
0
;
in
dex_in
t
overflow
=
0
;
for
(
std
::
ptr
diff_t
i
=
result
.
size
()
-
1
;
i
>
0
;
i
--
)
for
(
diff_
in
t
i
=
result
.
size
()
-
1
;
i
>
0
;
i
--
)
{
{
auto
z
=
result
[
i
]
+
overflow
;
auto
z
=
result
[
i
]
+
overflow
;
// Reset overflow
// Reset overflow
...
@@ -165,23 +177,23 @@ struct array
...
@@ -165,23 +177,23 @@ struct array
}
}
};
};
template
<
class
T
,
T
...
x
s
>
template
<
class
T
,
T
...
X
s
>
struct
integral_const_array
:
array
<
T
,
sizeof
...(
x
s
)
>
struct
integral_const_array
:
array
<
T
,
sizeof
...(
X
s
)
>
{
{
using
base_array
=
array
<
T
,
sizeof
...(
x
s
)
>
;
using
base_array
=
array
<
T
,
sizeof
...(
X
s
)
>
;
MIGRAPHX_DEVICE_CONSTEXPR
integral_const_array
()
:
base_array
({
x
s
...})
{}
MIGRAPHX_DEVICE_CONSTEXPR
integral_const_array
()
:
base_array
({
X
s
...})
{}
};
};
template
<
class
T
,
T
...
x
s
,
class
F
>
template
<
class
T
,
T
...
X
s
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
x
s
...
>
,
F
f
)
constexpr
auto
transform
(
integral_const_array
<
T
,
X
s
...
>
,
F
f
)
{
{
return
integral_const_array
<
T
,
f
(
x
s
)...
>
{};
return
integral_const_array
<
T
,
f
(
X
s
)...
>
{};
}
}
template
<
class
T
,
T
...
x
s
,
class
U
,
U
...
y
s
,
class
F
>
template
<
class
T
,
T
...
X
s
,
class
U
,
U
...
Y
s
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
x
s
...
>
,
integral_const_array
<
U
,
y
s
...
>
,
F
f
)
constexpr
auto
transform
(
integral_const_array
<
T
,
X
s
...
>
,
integral_const_array
<
U
,
Y
s
...
>
,
F
f
)
{
{
return
integral_const_array
<
T
,
f
(
x
s
,
y
s
)...
>
{};
return
integral_const_array
<
T
,
f
(
X
s
,
Y
s
)...
>
{};
}
}
template
<
index_int
...
Ns
>
template
<
index_int
...
Ns
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
100755 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#include <
hip/hip_runtime.h
>
#include <
migraphx/kernels/hip.hpp
>
namespace
migraphx
{
namespace
migraphx
{
inline
__host__
__device__
void
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
assert_fail
(
const
char
*
assertion
,
const
char
*
file
,
unsigned
int
line
,
const
char
*
function
)
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
// Workaround hip's broken abort on device code
#ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN
#else
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif
namespace
debug
{
struct
swallow
{
template
<
class
...
Ts
>
constexpr
swallow
(
Ts
&&
...)
{
}
};
template
<
size_t
N
>
struct
print_buffer
{
char
buffer
[
N
+
1
]
=
{
0
};
char
*
pos
=
buffer
;
constexpr
void
append
(
char
c
)
{
if
(
c
==
0
)
return
;
if
(
pos
<
buffer
+
N
)
{
*
pos
=
c
;
pos
++
;
}
}
template
<
class
T
,
class
=
decltype
(
T
{}
%
10
,
-
T
{}
)>
constexpr
void
append
(
T
i
)
{
if
(
i
<
0
)
{
append
(
'-'
);
i
=
-
i
;
}
char
c
=
(
i
%
10
)
+
'0'
;
if
(
i
>
9
)
append
(
i
/
10
);
append
(
c
);
}
constexpr
void
append
(
const
char
*
str
)
{
if
(
str
==
nullptr
)
return
;
int
i
=
512
;
while
(
*
str
!=
0
and
i
>
0
)
{
append
(
*
str
);
str
++
;
i
--
;
}
}
template
<
size_t
M
>
constexpr
void
append
(
const
char
(
&
array
)[
M
])
{
for
(
int
i
=
0
;
i
<
M
;
i
++
)
append
(
array
[
i
]);
}
};
template
<
class
...
Ts
>
__host__
__device__
void
print
(
const
Ts
&
...
xs
)
{
{
printf
(
"%s:%u: %s: assertion '%s' failed.
\n
"
,
file
,
line
,
function
,
assertion
);
print_buffer
<
1024
>
buffer
;
swallow
{(
buffer
.
append
(
xs
),
0
)...};
printf
(
"%s"
,
buffer
.
buffer
);
}
}
// namespace debug
struct
source_location
{
int
line
=
__builtin_LINE
();
const
char
*
file
=
__builtin_FILE
();
const
char
*
function
=
__builtin_FUNCTION
();
};
template
<
class
T
>
struct
source_location_capture
{
T
x
;
source_location
loc
;
template
<
class
U
,
class
=
decltype
(
T
(
U
{}
))>
constexpr
source_location_capture
(
U
px
,
source_location
ploc
=
source_location
{})
:
x
(
px
),
loc
(
ploc
)
{
}
constexpr
operator
source_location
()
const
{
return
loc
;
}
constexpr
operator
T
()
const
{
return
x
;
}
};
// noreturn cannot be used on this function because abort in hip is broken
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
assert_fail
(
const
T1
&
assertion
,
const
T2
&
file
,
const
T3
&
line
,
const
T4
&
function
)
{
// printf is broken on hip with more than one argument, so use a simple print functions instead
debug
::
print
(
file
,
":"
,
line
,
": "
,
function
,
": assertion '"
,
assertion
,
"' failed.
\n
"
);
// printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort
();
abort
();
}
}
template
<
class
...
Ts
>
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
assert_fail
(
const
source_location
&
loc
,
Ts
...
xs
)
{
debug
::
print
(
loc
.
file
,
":"
,
loc
.
line
,
": "
,
loc
.
function
,
": error: "
,
xs
...,
"
\n
"
);
abort
();
}
// NOLINTNEXTLINE
#define MIGRAPHX_ASSERT_FAIL(cond, ...) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
// NOLINTNEXTLINE
((cond) ? void(0) : [](auto... xs) { \
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
assert_fail(xs...); \
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
#else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
#endif
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/dfor.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_DFOR_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_DFOR_HPP
namespace
migraphx
{
// Multidimensional for loop
inline
constexpr
auto
dfor
()
{
return
[](
auto
f
)
{
f
();
};
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
dfor
(
T
x
,
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
for
(
T
i
=
0
;
i
<
x
;
i
++
)
{
dfor
(
xs
...)([
&
](
Ts
...
is
)
{
f
(
i
,
is
...);
});
}
};
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_DPP_HPP
#define MIGRAPHX_GUARD_KERNELS_DPP_HPP
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
#ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1
#endif
#if MIGRAPHX_HAS_DPP
constexpr
unsigned
int
dpp_row_shr
(
unsigned
int
x
)
{
return
0x110u
|
x
;
}
constexpr
unsigned
int
dpp_row_bcast
(
unsigned
int
x
)
{
unsigned
int
y
=
0
;
switch
(
x
)
{
case
15
:
y
=
0x142
;
break
;
case
31
:
y
=
0x143
;
break
;
default:
MIGRAPHX_UNREACHABLE
();
}
return
y
;
}
template
<
unsigned
int
DppCtrl
,
unsigned
int
RowMask
=
0xf
,
unsigned
int
BankMask
=
0xf
,
bool
BoundCtrl
=
false
,
class
T
>
__device__
T
dpp_mov
(
T
&
x
)
{
static
const
index_int
n
=
sizeof
(
T
)
<
4
?
1
:
sizeof
(
T
)
/
4
;
union
type
{
uint32_t
reg
[
n
];
T
data
;
};
type
output
{};
type
input
{};
// cppcheck-suppress unreadVariable
input
.
data
=
x
;
for
(
index_int
i
=
0
;
i
<
n
;
i
++
)
{
output
.
reg
[
i
]
=
__hip_move_dpp
(
input
.
reg
[
i
],
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
}
return
output
.
data
;
}
#endif
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
100755 → 100644
View file @
7e297b13
...
@@ -3,6 +3,14 @@
...
@@ -3,6 +3,14 @@
#include <migraphx/kernels/array.hpp>
#include <migraphx/kernels/array.hpp>
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
namespace
migraphx
{
namespace
migraphx
{
struct
swallow
struct
swallow
...
@@ -16,6 +24,19 @@ struct swallow
...
@@ -16,6 +24,19 @@ struct swallow
template
<
index_int
>
template
<
index_int
>
using
ignore
=
swallow
;
using
ignore
=
swallow
;
template
<
class
...
Fs
>
struct
overloaded
:
Fs
...
{
using
Fs
::
operator
()...;
overloaded
(
Fs
...
fs
)
:
Fs
(
fs
)...
{}
};
template
<
class
...
Fs
>
overloaded
<
Fs
...
>
overload
(
Fs
...
fs
)
{
return
{
fs
...};
}
namespace
detail
{
namespace
detail
{
template
<
class
R
>
template
<
class
R
>
...
@@ -116,7 +137,7 @@ constexpr auto by(F f)
...
@@ -116,7 +137,7 @@ constexpr auto by(F f)
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
{
{
swallow
{(
f
(
st
d
::
forward
<
Ts
>
(
xs
)),
0
)...};
swallow
{(
f
(
st
atic_cast
<
Ts
&&
>
(
xs
)),
0
)...};
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -124,12 +145,60 @@ constexpr void each_args(F)
...
@@ -124,12 +145,60 @@ constexpr void each_args(F)
{
{
}
}
template
<
class
F
,
class
T
>
constexpr
auto
fold_impl
(
F
&&
,
T
&&
x
)
{
return
static_cast
<
T
&&>
(
x
);
}
template
<
class
F
,
class
T
,
class
U
,
class
...
Ts
>
constexpr
auto
fold_impl
(
F
&&
f
,
T
&&
x
,
U
&&
y
,
Ts
&&
...
xs
)
{
return
fold_impl
(
f
,
f
(
static_cast
<
T
&&>
(
x
),
static_cast
<
U
&&>
(
y
)),
static_cast
<
Ts
&&>
(
xs
)...);
}
template
<
class
F
>
constexpr
auto
fold
(
F
f
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
pack
(
Ts
...
xs
)
constexpr
auto
pack
(
Ts
...
xs
)
{
{
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
template
<
class
G
,
class
F
>
constexpr
auto
join
(
G
g
,
F
f
)
{
return
f
([
=
](
auto
...
xs
)
{
return
g
(
xs
...);
});
}
template
<
class
G
,
class
F
,
class
...
Fs
>
constexpr
auto
join
(
G
g
,
F
f
,
Fs
...
fs
)
{
return
f
([
=
](
auto
...
xs
)
{
return
join
([
=
](
auto
...
ys
)
{
return
g
(
xs
...,
ys
...);
},
fs
...);
});
}
template
<
class
Compare
,
class
P1
,
class
P2
>
constexpr
auto
pack_compare
(
Compare
compare
,
P1
p1
,
P2
p2
)
{
return
p1
([
&
](
auto
...
xs
)
{
return
p2
([
&
](
auto
...
ys
)
{
auto
c
=
[
&
](
auto
x
,
auto
y
)
->
int
{
if
(
compare
(
x
,
y
))
return
1
;
else
if
(
compare
(
y
,
x
))
return
-
1
;
else
return
0
;
};
return
fold
([](
auto
x
,
auto
y
)
{
return
x
?
x
:
y
;
})(
c
(
xs
,
ys
)...,
0
);
});
});
}
template
<
index_int
N
>
template
<
index_int
N
>
constexpr
auto
arg_c
()
constexpr
auto
arg_c
()
{
{
...
@@ -142,34 +211,45 @@ constexpr auto arg(IntegralConstant ic)
...
@@ -142,34 +211,45 @@ constexpr auto arg(IntegralConstant ic)
return
arg_c
<
ic
>
();
return
arg_c
<
ic
>
();
}
}
inline
constexpr
auto
rotate_last
()
template
<
class
F
>
constexpr
auto
make_transform
(
F
f
)
{
{
return
[](
auto
...
xs
)
{
return
[
=
](
auto
...
xs
)
{
return
[
=
](
auto
g
)
{
return
f
(
g
,
xs
...);
};
};
return
[
=
](
auto
&&
f
)
{
return
sequence_c
<
sizeof
...(
xs
)
>
([
&
](
auto
...
is
)
{
constexpr
auto
size
=
sizeof
...(
is
);
return
f
(
arg_c
<
(
is
+
size
-
1
)
%
size
>
()(
xs
...)...);
});
};
};
}
}
// An arg transformation takes the arguments and then a function to take the new arguments:
// transform(xs...)([](auto... ys) { ... })
// The transform_args function takes a list of transformations and continually applies them
template
<
class
F
>
template
<
class
F
>
constexpr
auto
transform_args
(
F
f
)
constexpr
auto
transform_args
(
F
f
)
{
{
return
[
=
](
auto
...
xs
)
{
return
f
;
return
[
=
](
auto
g
)
{
return
f
(
xs
...)([
&
](
auto
...
ys
)
{
return
g
(
ys
...);
});
};
};
}
}
template
<
class
F
,
class
...
Fs
>
template
<
class
F
,
class
...
Fs
>
constexpr
auto
transform_args
(
F
f
,
Fs
...
fs
)
constexpr
auto
transform_args
(
F
f
,
Fs
...
fs
)
{
{
return
[
=
](
auto
...
xs
)
{
return
transform_args
(
f
)(
xs
...)(
transform_args
(
fs
...));
};
return
make_transform
([
=
](
auto
g
,
auto
...
xs
)
{
return
f
(
xs
...)([
=
](
auto
...
ys
)
{
return
transform_args
(
fs
...)(
ys
...)(
g
);
});
});
}
}
#define MIGRAPHX_LIFT(...) \
// identity transform
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
inline
constexpr
auto
transform_args
()
{
return
make_transform
([](
auto
f
,
auto
...
xs
)
{
return
f
(
xs
...);
});
}
// Rotate the first argument to the last argument
inline
constexpr
auto
rotate_last
()
{
return
make_transform
([](
auto
f
,
auto
...
xs
)
{
return
sequence_c
<
sizeof
...(
xs
)
>
([
&
](
auto
...
is
)
{
constexpr
auto
size
=
sizeof
...(
is
);
return
f
(
arg_c
<
(
is
+
size
-
1
)
%
size
>
()(
xs
...)...);
});
});
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace
migraphx
{
template
<
class
T
>
struct
gathernd_settings
{
T
batch_dims
{};
};
template
<
class
...
Ts
>
constexpr
gathernd_settings
<
Ts
...
>
make_gathernd_settings
(
Ts
...
xs
)
{
return
{
xs
...};
}
template
<
class
T
,
class
U
,
class
V
,
class
Settings
>
__device__
void
gathernd
(
const
T
&
data_t
,
const
U
&
indices_t
,
const
V
&
output_t
,
Settings
s
)
{
auto
ind
=
make_index
();
auto
batch_dims
=
s
.
batch_dims
;
auto
output_shape
=
output_t
.
get_shape
();
auto
indices_shape
=
indices_t
.
get_shape
();
auto
data_shape
=
data_t
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
std
::
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
const
std
::
size_t
j
=
i
/
slice_size
;
const
std
::
size_t
batch_idx
=
j
/
num_slices_per_batch
;
auto
*
slice_indices
=
indices_ptr
+
(
j
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
std
::
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
{
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
assert
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
std
::
multiplies
<
std
::
size_t
>
());
relative_slice_offset
+=
index
*
size_from_slice_dims
;
}
auto
slice_offset
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
output_t
[
i
]
=
data_t
[
slice_offset
+
i
%
slice_size
];
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/generic_constant.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
#define MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
namespace
migraphx
{
template
<
class
F
>
struct
generic_constant
{
static
constexpr
auto
value
=
F
{}();
using
value_type
=
decltype
(
value
);
using
type
=
generic_constant
;
constexpr
operator
value_type
()
const
noexcept
{
return
value
;
}
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
};
template
<
class
F
>
constexpr
generic_constant
<
F
>
make_generic_constant
(
F
)
{
return
{};
}
// NOLINTNEXTLINE
#define MIGRAPHX_MAKE_CONSTANT(x) \
make_generic_constant([] { \
struct fun \
{ \
constexpr auto operator()() const { return x; } \
}; \
return fun{}; \
}())
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
100755 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#include <
hip/hip_runtime.h
>
#include <
migraphx/kernels/hip.hpp
>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -12,23 +13,23 @@ struct index
...
@@ -12,23 +13,23 @@ struct index
index_int
local
=
0
;
index_int
local
=
0
;
index_int
group
=
0
;
index_int
group
=
0
;
__device__
index_int
nglobal
()
const
{
#ifdef MIGRAPHX_NGLOBAL
#ifdef MIGRAPHX_NGLOBAL
return
MIGRAPHX_NGLOBAL
;
constexpr
index_constant
<
MIGRAPHX_NGLOBAL
>
nglobal
()
const
{
return
{};
}
#else
#else
return
blockDim
.
x
*
gridDim
.
x
;
__device__
index_int
nglobal
()
const
#endif
{
return
blockDim
.
x
*
gridDim
.
x
;
// NOLINT
}
}
#endif
__device__
index_int
nlocal
()
const
{
#ifdef MIGRAPHX_NLOCAL
#ifdef MIGRAPHX_NLOCAL
return
MIGRAPHX_NLOCAL
;
constexpr
index_constant
<
MIGRAPHX_NLOCAL
>
nlocal
()
const
{
return
{};
}
#else
#else
return
blockDim
.
x
;
__device__
index_int
nlocal
()
const
#endif
{
return
blockDim
.
x
;
// NOLINT
}
}
#endif
template
<
class
F
>
template
<
class
F
>
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
...
@@ -53,7 +54,7 @@ struct index
...
@@ -53,7 +54,7 @@ struct index
inline
__device__
index
make_index
()
inline
__device__
index
make_index
()
{
{
return
index
{
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
,
threadIdx
.
x
,
blockIdx
.
x
};
return
index
{
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
,
threadIdx
.
x
,
blockIdx
.
x
};
// NOLINT
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp
View file @
7e297b13
...
@@ -5,28 +5,31 @@
...
@@ -5,28 +5,31 @@
namespace
migraphx
{
namespace
migraphx
{
template
<
class
T
,
T
v
>
template
<
class
T
,
T
V
>
struct
integral_constant
struct
integral_constant
{
{
static
constexpr
T
value
=
v
;
static
constexpr
T
value
=
V
;
using
value_type
=
T
;
using
value_type
=
T
;
using
type
=
integral_constant
;
using
type
=
integral_constant
;
constexpr
operator
value_type
()
const
noexcept
{
return
value
;
}
constexpr
operator
value_type
()
const
noexcept
{
return
value
;
}
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
static
constexpr
type
to
()
{
return
{};
}
};
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
template <class T, T
v
, class U, U w> \
template <class T, T
V
, class U, U w> \
constexpr inline integral_constant<decltype(
v
op w), (
v
op w)> operator op( \
constexpr inline integral_constant<decltype(
V
op w), (
V
op w)> operator op( \
integral_constant<T,
v
>, integral_constant<U, w>) noexcept \
integral_constant<T,
V
>, integral_constant<U, w>) noexcept \
{ \
{ \
return {}; \
return {}; \
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
template <class T, T
v
> \
template <class T, T
V
> \
constexpr inline integral_constant<decltype(op
v
), (op
v
)> operator op( \
constexpr inline integral_constant<decltype(op
V
), (op
V
)> operator op( \
integral_constant<T,
v
>) noexcept \
integral_constant<T,
V
>) noexcept \
{ \
{ \
return {}; \
return {}; \
}
}
...
@@ -45,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
...
@@ -45,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
==
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
==
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
!=
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
!=
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
^
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
^
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
|
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
|
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
||
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
||
)
...
@@ -64,8 +67,14 @@ using false_type = bool_constant<false>;
...
@@ -64,8 +67,14 @@ using false_type = bool_constant<false>;
template
<
index_int
N
>
template
<
index_int
N
>
using
index_constant
=
integral_constant
<
index_int
,
N
>
;
using
index_constant
=
integral_constant
<
index_int
,
N
>
;
template
<
auto
v
>
template
<
auto
V
>
static
constexpr
auto
_c
=
integral_constant
<
decltype
(
v
),
v
>
{};
static
constexpr
auto
_c
=
integral_constant
<
decltype
(
V
),
V
>
{};
// NOLINT
template
<
class
F
>
constexpr
auto
return_c
(
F
f
)
{
return
_c
<
f
()
>
;
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace
migraphx
{
template
<
class
F
,
class
Iterator
=
diff_int
>
struct
basic_iota_iterator
{
Iterator
index
;
F
f
;
using
difference_type
=
diff_int
;
using
reference
=
decltype
(
f
(
declval
<
Iterator
>
()));
using
value_type
=
remove_reference_t
<
reference
>
;
using
pointer
=
add_pointer_t
<
value_type
>
;
constexpr
basic_iota_iterator
&
operator
+=
(
diff_int
n
)
{
index
+=
n
;
return
*
this
;
}
constexpr
basic_iota_iterator
&
operator
-=
(
diff_int
n
)
{
index
-=
n
;
return
*
this
;
}
constexpr
basic_iota_iterator
&
operator
++
()
{
index
++
;
return
*
this
;
}
constexpr
basic_iota_iterator
&
operator
--
()
{
index
--
;
return
*
this
;
}
constexpr
basic_iota_iterator
operator
++
(
int
)
// NOLINT
{
basic_iota_iterator
it
=
*
this
;
index
++
;
return
it
;
}
constexpr
basic_iota_iterator
operator
--
(
int
)
// NOLINT
{
basic_iota_iterator
it
=
*
this
;
index
--
;
return
it
;
}
// TODO: operator->
constexpr
reference
operator
*
()
const
{
return
f
(
index
);
}
template
<
class
T
>
constexpr
reference
operator
[](
T
x
)
const
{
return
f
(
index
+
x
);
}
};
template
<
class
T
,
class
F
>
constexpr
basic_iota_iterator
<
F
,
T
>
make_basic_iota_iterator
(
T
x
,
F
f
)
{
return
basic_iota_iterator
<
F
,
T
>
{
x
,
f
};
}
template
<
class
F
,
class
Iterator
>
constexpr
basic_iota_iterator
<
F
,
Iterator
>
operator
+
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
diff_int
y
)
{
return
x
+=
y
;
}
template
<
class
F
,
class
Iterator
>
constexpr
basic_iota_iterator
<
F
,
Iterator
>
operator
+
(
diff_int
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
y
+
x
;
}
template
<
class
F
,
class
Iterator
>
constexpr
diff_int
operator
-
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
-
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
constexpr
basic_iota_iterator
<
F
,
Iterator
>
operator
-
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
diff_int
y
)
{
return
x
-=
y
;
}
template
<
class
F
,
class
Iterator
>
constexpr
bool
operator
==
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
==
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
constexpr
bool
operator
!=
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
!=
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
constexpr
bool
operator
<
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
<
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
constexpr
bool
operator
>
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
>
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
constexpr
bool
operator
>=
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
>=
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
constexpr
bool
operator
<=
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
return
x
.
index
<=
y
.
index
;
}
struct
defaul_iota_iterator
{
template
<
class
T
>
constexpr
auto
operator
()(
T
x
)
const
{
return
x
;
}
};
using
iota_iterator
=
basic_iota_iterator
<
defaul_iota_iterator
>
;
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace
migraphx
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
{
return
x
;
}
}
// namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::vec<migraphx::half, 2> x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::vec<migraphx::half, 2>{fname(x, xs...)}); \
template <class... Ts, index_int N, MIGRAPHX_REQUIRES(N % 2 == 0 && (N > 2))> \
auto __device__ name(migraphx::vec<migraphx::half, N> x, Ts... xs) \
{ \
return vec_packed_transform<2>(x, xs...)( \
[](auto... ys) -> migraphx::vec<migraphx::half, 2> { return fname(ys...); }); \
}
MIGRAPHX_DEVICE_MATH
(
abs
,
::
abs
)
MIGRAPHX_DEVICE_MATH
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH
(
exp
,
::
exp
)
MIGRAPHX_DEVICE_MATH
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH
(
log
,
::
log
)
MIGRAPHX_DEVICE_MATH
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH
(
rsqrt
,
::
rsqrt
)
MIGRAPHX_DEVICE_MATH
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH
(
sqrt
,
::
sqrt
)
MIGRAPHX_DEVICE_MATH
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH
(
tanh
,
::
tanh
)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
acos
,
::
acosf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
acosh
,
::
acoshf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
asin
,
::
asinf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
asinh
,
::
asinhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
atan
,
::
atanf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
atanh
,
::
atanhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
cos
,
::
cosf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
cosh
,
::
coshf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
rsqrt
,
::
rsqrtf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
sin
,
::
sinf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
sinh
,
::
sinhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
tan
,
::
tanf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
tanh
,
::
tanhf
)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH_HALF
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_HALF
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_HALF
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp
,
::
h2exp
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp10
,
::
h2exp10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log
,
::
h2log
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
template
<
class
T
,
class
U
>
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
{
return
cond
?
a
:
b
;
}
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
float
,
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
float
,
min
,
::
min
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
min
,
::
min
)
// Add overloads for half that calls the float version
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
fmaxf
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
fminf
)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
constexpr
auto
max
(
const
T
&
a
,
const
T
&
b
)
{
return
where
(
a
<
b
,
b
,
a
);
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
constexpr
auto
min
(
const
T
&
a
,
const
T
&
b
)
{
return
where
(
a
<
b
,
a
,
b
);
}
template
<
class
T
,
class
U
,
MIGRAPHX_REQUIRES
(
not
is_same
<
T
,
U
>{}
and
not
is_any_vec
<
T
,
U
>
())
>
constexpr
auto
max
(
const
T
&
a
,
const
U
&
b
)
{
return
max
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
template
<
class
T
,
class
U
,
MIGRAPHX_REQUIRES
(
not
is_same
<
T
,
U
>{}
and
not
is_any_vec
<
T
,
U
>
())
>
constexpr
auto
min
(
const
T
&
a
,
const
U
&
b
)
{
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
asin
)
MIGRAPHX_DEVICE_MATH_VEC
(
asinh
)
MIGRAPHX_DEVICE_MATH_VEC
(
atan
)
MIGRAPHX_DEVICE_MATH_VEC
(
atanh
)
MIGRAPHX_DEVICE_MATH_VEC
(
ceil
)
MIGRAPHX_DEVICE_MATH_VEC
(
cos
)
MIGRAPHX_DEVICE_MATH_VEC
(
cosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
erf
)
MIGRAPHX_DEVICE_MATH_VEC
(
exp
)
MIGRAPHX_DEVICE_MATH_VEC
(
floor
)
MIGRAPHX_DEVICE_MATH_VEC
(
isnan
)
MIGRAPHX_DEVICE_MATH_VEC
(
log
)
MIGRAPHX_DEVICE_MATH_VEC
(
max
)
MIGRAPHX_DEVICE_MATH_VEC
(
min
)
MIGRAPHX_DEVICE_MATH_VEC
(
pow
)
MIGRAPHX_DEVICE_MATH_VEC
(
round
)
MIGRAPHX_DEVICE_MATH_VEC
(
rsqrt
)
MIGRAPHX_DEVICE_MATH_VEC
(
sin
)
MIGRAPHX_DEVICE_MATH_VEC
(
sinh
)
MIGRAPHX_DEVICE_MATH_VEC
(
sqrt
)
MIGRAPHX_DEVICE_MATH_VEC
(
tan
)
MIGRAPHX_DEVICE_MATH_VEC
(
tanh
)
MIGRAPHX_DEVICE_MATH_VEC
(
where
)
template
<
class
T
,
class
U
>
constexpr
auto
convert
(
U
v
)
{
return
vec_transform
(
v
)([](
auto
x
)
->
T
{
return
x
;
});
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP
#define MIGRAPHX_GUARD_KERNELS_OPS_HPP
#include <migraphx/kernels/math.hpp>
namespace
migraphx
{
namespace
op
{
struct
sum
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
+
y
;
}
};
struct
product
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
*
y
;
}
};
struct
id
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
)
const
{
return
x
;
}
};
struct
mean
{
index_int
item_num
=
1
;
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
)
const
{
return
x
/
static_cast
<
T
>
(
item_num
);
}
};
struct
max
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
migraphx
::
max
(
x
,
y
);
}
};
struct
min
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
migraphx
::
min
(
x
,
y
);
}
};
}
// namespace op
struct
lowest
{
template
<
class
T
>
constexpr
operator
T
()
const
{
return
numeric_lowest
<
T
>
();
}
};
struct
highest
{
template
<
class
T
>
constexpr
operator
T
()
const
{
return
numeric_max
<
T
>
();
}
};
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_OPS_HPP
Prev
1
…
17
18
19
20
21
22
23
24
25
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment