Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b7bc3c2b
Commit
b7bc3c2b
authored
Sep 15, 2023
by
Jing Zhang
Browse files
allow packed elementwise_op
parent
d61d9edf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
105 additions
and
46 deletions
+105
-46
example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp
example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp
+7
-1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+95
-44
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+3
-1
No files found.
example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp
View file @
b7bc3c2b
...
...
@@ -44,10 +44,16 @@ using ELayout = Row;
struct
MultiATest
{
template
<
typename
A
,
typename
A0
,
typename
A1
>
__host__
__device__
constexpr
void
operator
()(
A
&
a
,
const
A0
&
a0
,
const
A1
&
a1
)
const
__host__
__device__
constexpr
void
operator
()(
A
&
a
,
const
A0
&
a0
,
const
A1
&
a1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
a
,
const
ck
::
half2_t
&
a0
,
const
ck
::
half2_t
&
a1
)
const
{
a
=
(
a0
+
a1
)
/
2
;
}
static
constexpr
ck
::
index_t
vec_len
=
2
;
};
struct
AlphaBetaAdd
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
View file @
b7bc3c2b
...
...
@@ -8,6 +8,18 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include <type_traits>
// Compile-time detector for a member named `vec_len`.
//
// Fallback case: used whenever the member-access expression
// `std::declval<T>().vec_len` is ill-formed; `value` is false.
template <typename T, typename Enable = void>
struct has_vec_len : std::bool_constant<false>
{
};

// Preferred case: std::void_t collapses to `void` only if `.vec_len` can be
// named on an object of type T, which makes this specialization match the
// defaulted second argument; `value` is true.
template <typename T>
struct has_vec_len<T, std::void_t<decltype(std::declval<T>().vec_len)>>
    : std::bool_constant<true>
{
};
namespace
ck
{
// Thread-level multi-source, multi-destination tensor slice data movement
...
...
@@ -131,7 +143,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number
<
num
>
{});
}
#if 1
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
...
...
@@ -143,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
auto
src_vectors
=
generate_vectors
<
SrcDatas
,
SrcScalarPerVector
>
();
auto
dst_vectors
=
generate_vectors
<
DstDatas
,
DstScalarPerVector
>
();
#if 0
// copy data from src_bufs into src_vectors
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
using
src_vector_t
=
typename
remove_cvref_t
<
decltype
(
src_vectors
[
i
])
>::
type
;
...
...
@@ -155,51 +166,94 @@ struct ThreadwiseTensorSliceTransfer_v7r2
src_bufs
[
i
].
template
Get
<
src_vector_t
>(
src_coords_
[
i
].
GetOffset
(),
is_src_valid
);
});
#endif
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
// copy data from src_bufs into src_vectors
using
src_vector_t
=
typename
remove_cvref_t
<
decltype
(
src_vectors
[
iSrc
])
>::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_descs
[
iSrc
],
src_coords_
[
iSrc
]);
src_vectors
(
iSrc
).
template
AsType
<
src_vector_t
>()(
I0
)
=
src_bufs
[
iSrc
].
template
Get
<
src_vector_t
>(
src_coords_
[
iSrc
].
GetOffset
(),
is_src_valid
);
// get reference to src data
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
if
constexpr
(
!
has_vec_len
<
decltype
(
element_op_
)
>::
value
)
{
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
else
{
constexpr
auto
elem_op_vec_len
=
decltype
(
element_op_
)
::
vec_len
;
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
},
Number
<
nSrc
>
{});
static_assert
(
is_same
<
remove_cvref_t
<
decltype
(
elem_op_vec_len
)
>
,
index_t
>::
value
,
"vec_len in element_op_ type is not index_t"
);
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
static_assert
(
elem_op_vec_len
==
2
||
elem_op_vec_len
==
4
||
elem_op_vec_len
==
8
,
"vec_len in element_op_ must be 2, 4, 8"
);
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
static_assert
(
SrcScalarPerVector
%
elem_op_vec_len
==
0
,
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!"
);
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
dst_vectors_tuple_
(
iAccess
)
=
dst_vectors
;
...
...
@@ -227,9 +281,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}
});
}
#endif
#if 1
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template
<
typename
DstBuffers
,
...
...
@@ -280,7 +332,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}
});
}
#endif
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
...
...
script/cmake-ck-dev.sh
View file @
b7bc3c2b
...
...
@@ -12,7 +12,9 @@ cmake
-save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx908
;gfx90a;gfx940
"
\
-D
GPU_TARGETS
=
"gfx908"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment