Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c42aded1
Commit
c42aded1
authored
Sep 16, 2022
by
turneram
Browse files
Formatting
parent
961cf059
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
32 deletions
+36
-32
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+9
-9
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+27
-23
No files found.
src/targets/gpu/jit/ck_elementwise.cpp
View file @
c42aded1
...
...
@@ -83,18 +83,18 @@ struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
//options.virtual_inputs = reduce_dims(inputs);
//std::cout << options.virtual_inputs << std::endl;
options
.
params
=
"-Wno-float-equal"
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
//
options.virtual_inputs = reduce_dims(inputs);
//
std::cout << options.virtual_inputs << std::endl;
options
.
params
=
"-Wno-float-equal"
;
// auto axis = find_fast_axis(options.virtual_inputs);
// auto vec = vectorize::elements(axis, options.virtual_inputs);
// auto preloads = preload::broadcasts(axis, options.virtual_inputs);
auto
axis
=
find_fast_axis
(
inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
options
.
kernel_name
=
"ck_elementwise_kernel"
;
auto
axis
=
find_fast_axis
(
inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
options
.
kernel_name
=
"ck_elementwise_kernel"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
c42aded1
...
...
@@ -95,7 +95,7 @@ template <ck::index_t ndim>
struct
CKBinaryElementwise2
{
template
<
class
Desc_M
>
/* constexpr */
__device__
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
)
/* constexpr */
__device__
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
)
{
auto
gridSize
=
72
;
auto
blockSize
=
1024
;
...
...
@@ -112,12 +112,16 @@ struct CKBinaryElementwise2
}
template
<
class
L
,
class
S
>
/* constexpr */
__device__
auto
MakeDescriptor_M
(
const
L
&
lengths
,
const
S
&
strides
)
/* constexpr */
__device__
auto
MakeDescriptor_M
(
const
L
&
lengths
,
const
S
&
strides
)
{
auto
tupleOfShape
=
generate_tuple
(
[
&
](
auto
I
)
{
return
static_cast
<
ck
::
index_t
>
(
lengths
[
I
]);
},
ck
::
Number
<
ndim
>
{});
auto
tupleOfStride
=
generate_tuple
(
[
&
](
auto
I
)
{
printf
(
"Stride %i: %i
\n
"
,
int
(
I
),
int
(
strides
[
I
]));
return
static_cast
<
ck
::
index_t
>
(
strides
[
I
]);
},
ck
::
Number
<
ndim
>
{});
[
&
](
auto
I
)
{
printf
(
"Stride %i: %i
\n
"
,
int
(
I
),
int
(
strides
[
I
]));
return
static_cast
<
ck
::
index_t
>
(
strides
[
I
]);
},
ck
::
Number
<
ndim
>
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
// merge nd to 1d desc - [s0 * s1 * ...]
if
constexpr
(
ndim
>
1
)
...
...
@@ -166,37 +170,37 @@ struct Div
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_elementwise
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
)
{
//auto idx = make_index();
constexpr
auto
a_lens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
a_strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
ck
::
index_t
a_ndim
=
a_lens
.
size
();
//decltype(a_lens.size()){};
//
auto idx = make_index();
constexpr
auto
a_lens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
a_strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
ck
::
index_t
a_ndim
=
a_lens
.
size
();
//
decltype(a_lens.size()){};
// if (idx.global == 0)
// printf("a_ndim: %i\n", int(a_ndim));
auto
a_bin_op
=
CKBinaryElementwise
<
a_ndim
>
{};
constexpr
auto
a_desc
=
a_bin_op
.
MakeDescriptor_M
(
a_lens
,
a_strides
);
auto
a_bin_op
=
CKBinaryElementwise
<
a_ndim
>
{};
constexpr
auto
a_desc
=
a_bin_op
.
MakeDescriptor_M
(
a_lens
,
a_strides
);
constexpr
auto
b_lens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
b_strides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
ck
::
index_t
b_ndim
=
b_lens
.
size
();
//decltype(b_lens.size()){};
constexpr
auto
b_lens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
b_strides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
ck
::
index_t
b_ndim
=
b_lens
.
size
();
//
decltype(b_lens.size()){};
// if (idx.global == 0)
// printf("b_ndim: %i\n", int(b_ndim));
auto
b_bin_op
=
CKBinaryElementwise
<
b_ndim
>
{};
constexpr
auto
b_desc
=
b_bin_op
.
MakeDescriptor_M
(
b_lens
,
b_strides
);
auto
b_bin_op
=
CKBinaryElementwise
<
b_ndim
>
{};
constexpr
auto
b_desc
=
b_bin_op
.
MakeDescriptor_M
(
b_lens
,
b_strides
);
constexpr
auto
c_lens
=
get_shape_c
<
V
>
{}.
lens
;
constexpr
auto
c_strides
=
get_shape_c
<
V
>
{}.
strides
;
constexpr
ck
::
index_t
c_ndim
=
c_lens
.
size
();
//decltype(c_lens.size()){};
constexpr
auto
c_lens
=
get_shape_c
<
V
>
{}.
lens
;
constexpr
auto
c_strides
=
get_shape_c
<
V
>
{}.
strides
;
constexpr
ck
::
index_t
c_ndim
=
c_lens
.
size
();
//
decltype(c_lens.size()){};
auto
c_bin_op
=
CKBinaryElementwise
<
c_ndim
>
{};
constexpr
auto
c_desc
=
c_bin_op
.
MakeDescriptor_M
(
c_lens
,
c_strides
);
using
AGridDesc_M
=
decltype
(
a_desc
);
using
BGridDesc_M
=
decltype
(
b_desc
);
using
CGridDesc_M
=
decltype
(
c_desc
);
constexpr
ck
::
index_t
MPerThread
=
8
;
using
AGridDesc_M
=
decltype
(
a_desc
);
using
BGridDesc_M
=
decltype
(
b_desc
);
using
CGridDesc_M
=
decltype
(
c_desc
);
constexpr
ck
::
index_t
MPerThread
=
8
;
constexpr
ck
::
index_t
AScalarPerVector
=
8
;
constexpr
ck
::
index_t
BScalarPerVector
=
8
;
constexpr
ck
::
index_t
CScalarPerVector
=
8
;
using
GridwiseBinEltwise
=
ck
::
GridwiseBinaryElementwise_1D
<
ADataType
,
using
GridwiseBinEltwise
=
ck
::
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
CDataType
,
...
...
@@ -208,7 +212,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
auto
op
=
Add
{};
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
b_desc
,
c_desc
,
op
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment