Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8c1ad9e6
Commit
8c1ad9e6
authored
Sep 12, 2022
by
turneram
Browse files
Call from global function
parent
bf523dbe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
161 additions
and
30 deletions
+161
-30
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+142
-9
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+19
-21
No files found.
src/targets/gpu/jit/ck_elementwise.cpp
View file @
8c1ad9e6
...
...
@@ -25,6 +25,9 @@
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
...
...
@@ -39,22 +42,104 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
ck_elementwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_elementwise.hpp>
//
#include <migraphx/kernels/ck_elementwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace migraphx {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using ElementwiseFunctor = float;
static constexpr auto I0 = ck::Number<0>{};
template <class L, class S, class N>
constexpr auto MakeDescriptor_M(const L& lengths, const S& strides, const N& ndim)
{
auto gridSize = 72;
auto blockSize = 1024;
//constexpr auto ndim = 1;
// auto idx = make_index();
auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]); },
ck::Number<ndim>{});
auto tupleOfStride = generate_tuple(
[&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
auto desc_m = desc;
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(ndim > 1)
{
desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, ck::Number<ndim>{})),
make_tuple(ck::Sequence<0>{}));
}
const auto M = desc_m.GetLength(I0);
const ck::index_t loop_step = /* idx.nglobal(); // */ gridSize * blockSize /* * MPerThread */;
const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(ck::make_right_pad_transform(M, pad)),
make_tuple(ck::Sequence<0>{}),
make_tuple(ck::Sequence<0>{}));
return desc_m_pad;
}
struct Add
{
template <typename Y, typename X0, typename X1>
__device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = x0 + x1;
};
};
extern "C" {
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck_elementwise(xs...);
make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
constexpr auto lengths = get_shape_c<decltype(a_t)>{}.lens;
constexpr auto strides = get_shape_c<decltype(a_t)>{}.strides;
constexpr auto ndim = _c<decltype(lengths.size()){}>[1];
constexpr auto a_desc = MakeDescriptor_M(lengths, strides, ndim);
using AGridDesc_M = decltype(a_desc);
using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
CDataType,
AGridDesc_M,
AGridDesc_M,
AGridDesc_M,
Add,
1,
1,
1,
1>;
auto op = Add{};
GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op);
});
}
...
...
@@ -64,20 +149,68 @@ __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"
;
// NOLINTNEXTLINE
// static const char* const ck_elementwise_kernel = R"__migraphx__(
// #include <migraphx/kernels/ck_elementwise.hpp>
// #include <migraphx/kernels/ops.hpp>
// #include <migraphx/kernels/integral_constant.hpp>
// #include <migraphx/kernels/generic_constant.hpp>
// #include <args.hpp>
// namespace migraphx {
// extern "C" {
// __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
// {
// make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
// ck_elementwise(xs...);
// });
// }
// }
// } // namespace migraphx
// )__migraphx__";
struct
ck_elementwise_compiler
:
compiler
<
ck_elementwise_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_elementwise"
};
}
static
std
::
size_t
oversubscribe_if
(
bool
b
)
{
if
(
b
)
return
256
;
else
return
1
;
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
// hip_compile_options options;
// auto out_s = inputs.back();
// options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
// options.inputs = inputs;
// options.output = out_s;
// options.kernel_name = "ck_elementwise_kernel";
// options.virtual_inputs = inputs;
// return compile_hip_code_object(ck_elementwise_kernel, options);
hip_compile_options
options
;
auto
out_s
=
inputs
.
back
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
options
.
virtual_inputs
);
options
.
kernel_name
=
"ck_elementwise_kernel"
;
options
.
virtual_inputs
=
inputs
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec
.
size
,
oversubscribe_if
(
not
preloads
.
is_preloading
())));
return
compile_hip_code_object
(
ck_elementwise_kernel
,
options
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
8c1ad9e6
...
...
@@ -91,28 +91,26 @@ template <class T, class U, class V>
__device__
void
ck_elementwise
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
)
{
auto
idx
=
make_index
();
if
(
idx
.
global
==
0
)
{
constexpr
auto
lengths
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
auto
a_desc
=
MakeDescriptor_M
(
lengths
,
strides
,
1
);
constexpr
auto
lengths
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
auto
a_desc
=
MakeDescriptor_M
(
lengths
,
strides
,
1
);
using
AGridDesc_M
=
decltype
(
a_desc
);
using
GridwiseBinEltwise
=
ck
::
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
CDataType
,
AGridDesc_M
,
AGridDesc_M
,
AGridDesc_M
,
Add
,
1
,
1
,
1
,
1
>
;
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
a_desc
,
a_desc
,
op
);
}
using
AGridDesc_M
=
decltype
(
a_desc
);
using
GridwiseBinEltwise
=
ck
::
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
CDataType
,
AGridDesc_M
,
AGridDesc_M
,
AGridDesc_M
,
Add
,
1
,
1
,
1
,
1
>
;
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
a_desc
,
a_desc
,
op
);
}
}
// namespace migraphx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment