Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5f1fbd80
Commit
5f1fbd80
authored
May 28, 2022
by
rocking
Browse files
Add n-ary gridwise kernel
parent
3e6c2610
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
174 additions
and
0 deletions
+174
-0
include/ck/tensor_operation/gpu/grid/gridwise_nway_elementwise_1d.hpp
...ensor_operation/gpu/grid/gridwise_nway_elementwise_1d.hpp
+171
-0
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+3
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_nway_elementwise_1d.hpp
0 → 100644
View file @
5f1fbd80
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
template
<
typename
GridwiseEltwise
,
typename
SrcDataTypes
,
typename
DstDataTypes
,
typename
SrcGridDesc_M
,
typename
DstGridDesc_M
,
typename
ElementwiseFunctor
>
__global__
void
kernel_nway_elementwise_1d
(
const
SrcDataTypes
p_src_globals
,
DstDataTypes
p_dst_globals
,
const
SrcGridDesc_M
src_grid_desc_ms
,
const
DstGridDesc_M
dst_grid_desc_ms
,
const
ElementwiseFunctor
functor
)
{
GridwiseEltwise
::
Run
(
p_src_globals
,
p_dst_globals
,
src_grid_desc_ms
,
dst_grid_desc_ms
,
functor
);
}
template
<
typename
SrcDataTypes
,
typename
DstDataTypes
,
typename
ComputeDataType
,
typename
SrcGridDesc_M
,
typename
DstGridDesc_M
,
typename
ElementwiseFunctor
,
index_t
MPerThread
,
typename
SrcScalarPerVector
,
typename
DstScalarPerVector
>
struct
GridwiseNWayElementwise_1D
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
thread_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MPerThread
>
{}));
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
__device__
auto
CalculateElementwiseIndex
()
{
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
return
make_multi_index
(
global_thread_id
*
MPerThread
);
}
__device__
static
void
Run
(
const
SrcDataTypes
p_src_globals
,
DstDataTypes
p_dst_globals
,
const
SrcGridDesc_M
src_grid_desc_ms
,
const
DstGridDesc_M
dst_grid_desc_ms
,
const
ElementwiseFunctor
functor
)
{
constexpr
auto
Isrc_size
=
Number
<
SrcDataTypes
::
Size
()
>
{};
constexpr
auto
Idst_size
=
Number
<
DstDataTypes
::
Size
()
>
{};
const
auto
src_global_buf
=
generate_tuple
(
[
&
](
auto
I
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_src_globals
[
I
],
src_grid_desc_ms
[
I
].
GetElementSpaceSize
());
},
Isrc_size
);
auto
dst_global_buf
=
generate_tuple
(
[
&
](
auto
I
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dst_globals
[
I
],
dst_grid_desc_ms
[
I
].
GetElementSpaceSize
());
},
Idst_size
);
auto
src_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MPerThread
,
true
>
{};
},
Isrc_size
);
auto
dst_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MPerThread
,
true
>
{};
},
Idst_size
);
const
auto
thread_store_global_offset
=
CalculateElementwiseIndex
();
auto
src_global_load
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
p_src_global
=
p_src_globals
[
I
];
auto
p_src_grid_desc_m
=
src_grid_desc_ms
[
I
];
return
ThreadwiseTensorSliceTransfer_v2
<
remove_const_t
<
remove_pointer_t
<
decltype
(
p_src_global
)
>>
,
ComputeDataType
,
decltype
(
p_src_grid_desc_m
),
decltype
(
thread_desc_m
),
Sequence
<
MPerThread
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SrcScalarPerVector
::
At
(
I
),
// ScalarPerVector
1
,
// SrcScalarStrideInVector
false
>
{
p_src_grid_desc_m
,
thread_store_global_offset
};
},
Isrc_size
);
auto
dst_global_write
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
p_dst_global
=
p_dst_globals
[
I
];
auto
p_dst_grid_desc_m
=
dst_grid_desc_ms
[
I
];
return
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
remove_pointer_t
<
decltype
(
p_dst_global
)
>
,
decltype
(
thread_desc_m
),
decltype
(
p_dst_grid_desc_m
),
PassThrough
,
Sequence
<
MPerThread
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// DstVectorDim
DstScalarPerVector
::
At
(
I
),
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
false
>
{
p_dst_grid_desc_m
,
thread_store_global_offset
,
PassThrough
{}};
},
Idst_size
);
const
index_t
blockSize
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
M
=
dst_grid_desc_ms
[
I0
].
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
MPerThread
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
index_t
num_iter
=
M
/
(
loop_step
);
do
{
// read and process MPerThread elements
static_for
<
0
,
Isrc_size
,
1
>
{}([
&
](
auto
I
)
{
src_global_load
(
I
).
Run
(
src_grid_desc_ms
[
I
],
src_global_buf
[
I
],
thread_desc_m
,
make_tuple
(
I0
),
src_thread_buf
(
I
));
src_global_load
(
I
).
MoveSrcSliceWindow
(
src_grid_desc_ms
[
I
],
loop_step_index
);
});
static_for
<
0
,
MPerThread
,
1
>
{}([
&
](
auto
m
)
{
constexpr
auto
offset
=
thread_desc_m
.
CalculateOffset
(
make_tuple
(
m
));
const
auto
src_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
src_thread_buf
[
I
][
Number
<
offset
>
{}];
},
Isrc_size
);
auto
dst_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
dst_thread_buf
(
I
)(
Number
<
offset
>
{});
},
Idst_size
);
(
void
)
src_tuple
;
(
void
)
dst_tuple
;
// TODO - n-ary functor
// functor(src_tuple, dst_tuple);
});
static_for
<
0
,
Idst_size
,
1
>
{}([
&
](
auto
I
)
{
dst_global_write
(
I
).
Run
(
thread_desc_m
,
make_tuple
(
I0
),
// SrcSliceOriginIdx
dst_thread_buf
[
I
],
dst_grid_desc_ms
[
I
],
dst_global_buf
(
I
));
dst_global_write
(
I
).
MoveDstSliceWindow
(
dst_grid_desc_ms
[
I
],
loop_step_index
);
});
}
while
(
--
num_iter
);
}
};
}
// namespace ck
include/ck/utility/type.hpp
View file @
5f1fbd80
...
...
@@ -32,6 +32,9 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
using
remove_const_t
=
typename
std
::
remove_const
<
T
>::
type
;
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment