gaoqiong / MIGraphX / Commits

Commit 8c1ad9e6, authored Sep 12, 2022 by turneram

Call from global function

Parent: bf523dbe
Showing 2 changed files with 161 additions and 30 deletions:

  src/targets/gpu/jit/ck_elementwise.cpp                                (+142, -9)
  src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp   (+19, -21)

src/targets/gpu/jit/ck_elementwise.cpp
@@ -25,6 +25,9 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/compile_gen.hpp>
 #include <migraphx/gpu/compile_hip_code_object.hpp>
 #include <migraphx/gpu/compile_hip.hpp>
 #include <migraphx/ranges.hpp>
@@ -39,22 +42,104 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

+using namespace migraphx::gpu::gen; // NOLINT
+
 // NOLINTNEXTLINE
 static const char* const ck_elementwise_kernel = R"__migraphx__(
-#include <migraphx/kernels/ck_elementwise.hpp>
+// #include <migraphx/kernels/ck_elementwise.hpp>
 #include <migraphx/kernels/ops.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
+#include <migraphx/kernels/index.hpp>
+#include <migraphx/kernels/algorithm.hpp>
+#include <migraphx/kernels/integral_constant.hpp>
+#include <migraphx/kernels/tensor_view.hpp>
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"

 namespace migraphx {

+using ADataType          = float;
+using BDataType          = float;
+using CDataType          = float;
+using ElementwiseFunctor = float;
+
+static constexpr auto I0 = ck::Number<0>{};
+
+template <class L, class S, class N>
+constexpr auto MakeDescriptor_M(const L& lengths, const S& strides, const N& ndim)
+{
+    auto gridSize  = 72;
+    auto blockSize = 1024;
+    // constexpr auto ndim = 1;
+    // auto idx = make_index();
+    auto tupleOfShape = generate_tuple(
+        [&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<ndim>{});
+    auto tupleOfStride = generate_tuple(
+        [&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
+    const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
+    auto desc_m     = desc;
+    // merge nd to 1d desc - [s0 * s1 * ...]
+    if constexpr(ndim > 1)
+    {
+        desc_m = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(tupleOfShape)),
+            make_tuple(generate_sequence_v2([&](auto I) { return I; }, ck::Number<ndim>{})),
+            make_tuple(ck::Sequence<0>{}));
+    }
+    const auto M = desc_m.GetLength(I0);
+    const ck::index_t loop_step = /* idx.nglobal(); // */ gridSize * blockSize /* * MPerThread */;
+    const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
+    const auto desc_m_pad =
+        transform_tensor_descriptor(desc_m,
+                                    make_tuple(ck::make_right_pad_transform(M, pad)),
+                                    make_tuple(ck::Sequence<0>{}),
+                                    make_tuple(ck::Sequence<0>{}));
+    return desc_m_pad;
+}
+
+struct Add
+{
+    template <typename Y, typename X0, typename X1>
+    __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        y = x0 + x1;
+    };
+};
+
 extern "C" {
 __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
 {
-    make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
-        ck_elementwise(xs...);
+    make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
+        constexpr auto lengths = get_shape_c<decltype(a_t)>{}.lens;
+        constexpr auto strides = get_shape_c<decltype(a_t)>{}.strides;
+        constexpr auto ndim    = _c<decltype(lengths.size()){}>[1];
+        constexpr auto a_desc  = MakeDescriptor_M(lengths, strides, ndim);
+        using AGridDesc_M      = decltype(a_desc);
+        using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
+                                                                    BDataType,
+                                                                    CDataType,
+                                                                    CDataType,
+                                                                    AGridDesc_M,
+                                                                    AGridDesc_M,
+                                                                    AGridDesc_M,
+                                                                    Add,
+                                                                    1,
+                                                                    1,
+                                                                    1,
+                                                                    1>;
+        auto op = Add{};
+        GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op);
     });
 }
 }
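
Note on MakeDescriptor_M above: after merging the shape to one dimension, it right-pads the flattened length M up to a multiple of gridSize * blockSize (72 * 1024 = 73728), presumably so the strided per-thread loop divides the domain evenly. A minimal standalone sketch of just that padding arithmetic, using plain integers instead of the ck:: descriptor types and a hypothetical tensor length:

// Padding arithmetic only; mirrors integer_least_multiple(M, loop_step) - M
// from MakeDescriptor_M. No CK or HIP dependencies; M is made up.
#include <cstdio>

int main()
{
    const long grid_size  = 72;                     // same constants as the kernel string
    const long block_size = 1024;
    const long loop_step  = grid_size * block_size; // 73728

    const long m      = 1000000;                    // hypothetical flattened length
    const long padded = ((m + loop_step - 1) / loop_step) * loop_step;
    const long pad    = padded - m;

    std::printf("M=%ld padded=%ld pad=%ld\n", m, padded, pad); // pad = 32192
    return 0;
}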
@@ -64,20 +149,68 @@ __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
 )__migraphx__";

+// NOLINTNEXTLINE
+// static const char* const ck_elementwise_kernel = R"__migraphx__(
+// #include <migraphx/kernels/ck_elementwise.hpp>
+// #include <migraphx/kernels/ops.hpp>
+// #include <migraphx/kernels/integral_constant.hpp>
+// #include <migraphx/kernels/generic_constant.hpp>
+// #include <args.hpp>
+// namespace migraphx {
+// extern "C" {
+// __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
+// {
+//     make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
+//         ck_elementwise(xs...);
+//     });
+// }
+// }
+// } // namespace migraphx
+// )__migraphx__";
+
 struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
 {
     std::vector<std::string> names() const { return {"ck_elementwise"}; }

+    static std::size_t oversubscribe_if(bool b)
+    {
+        if(b)
+            return 256;
+        else
+            return 1;
+    }
+
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
+        // hip_compile_options options;
+        // auto out_s = inputs.back();
+        // options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
+        // options.inputs = inputs;
+        // options.output = out_s;
+        // options.kernel_name = "ck_elementwise_kernel";
+        // options.virtual_inputs = inputs;
+        // return compile_hip_code_object(ck_elementwise_kernel, options);
         hip_compile_options options;
-        auto out_s = inputs.back();
-        options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
         options.inputs = inputs;
-        options.output = out_s;
+        options.output         = inputs.back();
+        options.virtual_inputs = reduce_dims(inputs);
+        options.params         = "-Wno-float-equal";
+        auto axis     = find_fast_axis(options.virtual_inputs);
+        auto vec      = vectorize::elements(axis, options.virtual_inputs);
+        auto preloads = preload::broadcasts(axis, options.virtual_inputs);
         options.kernel_name = "ck_elementwise_kernel";
-        options.virtual_inputs = inputs;
+        options.set_launch_params(
+            v,
+            compute_global_for(ctx,
+                               options.output.elements() / vec.size,
+                               oversubscribe_if(not preloads.is_preloading())));
         return compile_hip_code_object(ck_elementwise_kernel, options);
     }
...
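
On compile_op: the rewritten version sizes the launch from the output element count divided by the vectorization width, and applies an oversubscription factor of 256 only when broadcast inputs are not being preloaded (oversubscribe_if above). A rough standalone sketch of just that selection logic, with made-up sizes; the actual launch policy lives in compute_global_for, which is not shown in this diff:

// Illustrative only: reproduces oversubscribe_if from the diff and shows the
// two quantities compile_op feeds to compute_global_for. Sizes are made up.
#include <cstddef>
#include <cstdio>

static std::size_t oversubscribe_if(bool b)
{
    if(b)
        return 256;
    else
        return 1;
}

int main()
{
    const std::size_t output_elements = 1 << 20; // hypothetical output size
    const std::size_t vec_size        = 4;       // hypothetical vectorization width
    const bool preloading             = false;   // assume no broadcasts are preloaded

    const std::size_t work_items = output_elements / vec_size;
    const std::size_t oversub    = oversubscribe_if(not preloading);

    std::printf("work items: %zu, oversubscription: %zu\n", work_items, oversub);
    return 0;
}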

src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
@@ -91,28 +91,26 @@ template <class T, class U, class V>
 __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
 {
     auto idx = make_index();
-    if(idx.global == 0)
-    {
     constexpr auto lengths = get_shape_c<T>{}.lens;
     constexpr auto strides = get_shape_c<T>{}.strides;
     constexpr auto a_desc  = MakeDescriptor_M(lengths, strides, 1);
     using AGridDesc_M      = decltype(a_desc);
     using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
                                                                 BDataType,
                                                                 CDataType,
                                                                 CDataType,
                                                                 AGridDesc_M,
                                                                 AGridDesc_M,
                                                                 AGridDesc_M,
                                                                 Add,
                                                                 1,
                                                                 1,
                                                                 1,
                                                                 1>;
     auto op = Add{};
     GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op);
-    }
 }
 } // namespace migraphx
...
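
For reference, the Add functor shared by both files is a plain element-wise sum, y = x0 + x1. A host-side sketch of the same computation over flattened buffers (hypothetical sizes, no CK or HIP dependencies):

// CPU reference for the element-wise operation the kernel performs:
// c[i] = a[i] + b[i] over the flattened tensors. Sizes are made up.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Add
{
    // Mirrors the device functor in the diff: y = x0 + x1
    template <class Y, class X0, class X1>
    constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        y = x0 + x1;
    }
};

int main()
{
    const std::size_t n = 8; // hypothetical flattened length
    std::vector<float> a(n, 1.5f), b(n, 2.0f), c(n, 0.0f);

    Add op{};
    for(std::size_t i = 0; i < n; ++i)
        op(c[i], a[i], b[i]);

    std::printf("c[0] = %g\n", c[0]); // prints 3.5
    return 0;
}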