Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e4737e2f
Commit
e4737e2f
authored
Sep 07, 2022
by
turneram
Browse files
Almost working
parent
27c6f5d3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
103 additions
and
35 deletions
+103
-35
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+3
-3
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+99
-31
test/verify/0ck_elementwise_test.cpp
test/verify/0ck_elementwise_test.cpp
+1
-1
No files found.
src/targets/gpu/jit/ck_elementwise.cpp
View file @
e4737e2f
...
@@ -49,13 +49,13 @@ static const char* const ck_elementwise_kernel = R"__migraphx__(
...
@@ -49,13 +49,13 @@ static const char* const ck_elementwise_kernel = R"__migraphx__(
namespace migraphx {
namespace migraphx {
extern "C" {
extern "C" {
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
{
ck_elementwise(a_p, b_p, c_p);
make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck_elementwise(xs...);
});
}
}
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
e4737e2f
...
@@ -26,51 +26,119 @@
...
@@ -26,51 +26,119 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include
<iostream>
#include
"ck/device_utility/device_prop.hpp"
#include
<cstdlib>
#include
"ck/device_utility/kernel_launch.hpp"
#include "ck/
ck
.hpp"
#include "ck/
tensor_operation/gpu/device/device_base
.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/
element/
binary_element
_
wise_
operation
.hpp"
#include "ck/tensor_operation/gpu/
grid/gridwise_
binary_elementwise_
1d
.hpp"
namespace
migraphx
{
namespace
migraphx
{
// using F16 = ck::half_t;
using
ADataType
=
float
;
// using F32 = float;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
ElementwiseFunctor
=
float
;
// using ABDataType = F16
;
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{}
;
//
using
CDataType = F16
;
using
index_t
=
index_int
;
// using Add = ck::tensor_operation::element_wise::Add;
template
<
class
L
,
class
S
>
__host__
__device__
constexpr
auto
MakeDescriptor_M
(
const
L
&
lengths
,
const
S
&
strides
)
{
auto
idx
=
make_index
();
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
static_cast
<
ck
::
index_t
>
(
lengths
[
I
]);
},
ck
::
Number
<
1
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
static_cast
<
ck
::
index_t
>
(
strides
[
I
]);
},
ck
::
Number
<
1
>
{});
const
auto
desc_m
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
// using DeviceElementwiseAddInstance =
const
auto
M
=
desc_m
.
GetLength
(
I0
);
// ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
const
index_t
loop_step
=
idx
.
nglobal
();
//gridSize * blockSize * MPerThread;
// ck::Tuple<CDataType>,
const
auto
pad
=
ck
::
math
::
integer_least_multiple
(
M
,
loop_step
)
-
M
;
// Add,
const
auto
desc_m_pad
=
// 1,
transform_tensor_descriptor
(
desc_m
,
// 8,
make_tuple
(
ck
::
make_right_pad_transform
(
M
,
pad
)),
// ck::Sequence<8, 8>,
make_tuple
(
ck
::
Sequence
<
0
>
{}),
// ck::Sequence<8>>;
make_tuple
(
ck
::
Sequence
<
0
>
{}));
return
desc_m_pad
;
}
__host__
__device__
void
ck_elementwise
(
void
*
/* a_p */
,
void
*
/* b_p */
,
void
*
/* c_p */
)
struct
Add
{
{
// ck::index_t M = 1024;
template
<
typename
Y
,
typename
X0
,
typename
X1
>
// std::array<const void*, 2> input = {a_p,
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
;
// b_p};
// std::array<void*, 1> output = {c_p};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
y
=
x0
+
x1
;
};
// std::array<ck::index_t, 1> abc_lengths = {M};
template
<
>
// std::array<ck::index_t, 1> a_strides = {1};
__host__
__device__
constexpr
void
// std::array<ck::index_t, 1> b_strides = {1};
operator
()
<
double
>
(
double
&
y
,
const
double
&
x0
,
const
double
&
x1
)
const
// std::array<ck::index_t, 1> c_strides = {1};
{
y
=
x0
+
x1
;
};
};
// auto broadcastAdd = DeviceElementwiseAddInstance{};
// auto argument = broadcastAdd.MakeArgumentPointer(
// abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
// broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, false});
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_elementwise
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
)
{
//auto add = [](auto a, auto b) { return a + b; };
auto
lengths
=
a_t
.
get_shape
().
lens
;
auto
strides
=
a_t
.
get_shape
().
strides
;
auto
a_desc
=
MakeDescriptor_M
(
lengths
,
strides
);
using
AGridDesc_M
=
decltype
(
a_desc
);
//using Add = ck::tensor_operation::element_wise::Add;
using
GridwiseBinEltwise
=
ck
::
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
CDataType
,
AGridDesc_M
,
AGridDesc_M
,
AGridDesc_M
,
Add
,
8
,
8
,
8
,
8
>
;
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
a_desc
,
a_desc
,
op
);
// auto kernel = ck::kernel_binary_elementwise_1d<GridwiseBinEltwise,
// ADataType,
// BDataType,
// CDataType,
// AGridDesc_M,
// AGridDesc_M,
// AGridDesc_M,
// Add>;
// kernel(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, Add);
// Argument arg{a_t.data(), b_t.data(), c_t.data(), c_t.get_shape().lens, a_t.get_shape().strides, b_t.get_shape().strides, c_t.get_shape().strides,
// add};
// auto lengths = a_t.get_shape().lens;
// auto strides = a_t.get_shape().strides;
// auto idx = make_index();
// b_t.get_shape();
// c_t.get_shape();
// auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, ck::Number<1>{});
// auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, ck::Number<1>{});
// const auto desc_m = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// const auto M = desc_m.GetLength(I0);
// const ck::index_t loop_step = idx.nglobal();//gridSize * blockSize * MPerThread;
// const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
// const auto desc_m_pad =
// transform_tensor_descriptor(desc_m,
// make_tuple(ck::make_right_pad_transform(M, pad)),
// make_tuple(ck::Sequence<0>{}),
// make_tuple(ck::Sequence<0>{}));
}
}
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
test/verify/0ck_elementwise_test.cpp
View file @
e4737e2f
...
@@ -33,7 +33,7 @@ struct ck_elementwise : verify_program<ck_elementwise>
...
@@ -33,7 +33,7 @@ struct ck_elementwise : verify_program<ck_elementwise>
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
10
,
20
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
20
}};
// migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
// migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment