Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bd22abb5
"src/include/Sequence.hpp" did not exist on "766b0a9eafe29a5d2a75c350345e54165ceaf405"
Commit
bd22abb5
authored
Apr 18, 2020
by
Chao Liu
Browse files
refactor
parent
e131f6aa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
48 deletions
+37
-48
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
.../tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+37
-48
No files found.
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @
bd22abb5
...
@@ -52,12 +52,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -52,12 +52,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
is_same
<
BlockSliceLengths
,
decltype
(
ThreadSliceLengths
{}
*
ThreadClusterLengths
{})
>
{},
is_same
<
BlockSliceLengths
,
decltype
(
ThreadSliceLengths
{}
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
"wrong! threads should be mapped to cover entire slicing window"
);
// map threads to cluster
static_assert
(
BlockSize
>=
mThreadClusterDesc
.
GetElementSize
(),
constexpr
auto
thread_cluster_desc
=
"wrong! BlockSize too small"
);
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
const
auto
thread_cluster_id
=
const
auto
thread_cluster_id
=
thread_cluster_desc
.
CalculateClusterIndex
(
get_thread_local_1d_id
());
mThreadClusterDesc
.
CalculateClusterIndex
(
get_thread_local_1d_id
());
const
auto
thread_data_id_begin
=
thread_cluster_id
*
ThreadSliceLengths
{};
const
auto
thread_data_id_begin
=
thread_cluster_id
*
ThreadSliceLengths
{};
...
@@ -67,6 +69,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -67,6 +69,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
mThreadwiseStore
.
SetSrcSliceOrigin
(
make_zero_array
<
index_t
,
nDim
>
());
mThreadwiseStore
.
SetSrcSliceOrigin
(
make_zero_array
<
index_t
,
nDim
>
());
mThreadwiseStore
.
SetDstSliceOrigin
(
dst_block_slice_origin
+
thread_data_id_begin
);
mThreadwiseStore
.
SetDstSliceOrigin
(
dst_block_slice_origin
+
thread_data_id_begin
);
}
}
}
__device__
static
constexpr
index_t
GetThreadBufferSize
()
__device__
static
constexpr
index_t
GetThreadBufferSize
()
{
{
...
@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
bool
has_optimized_address_calculation
=
constexpr
bool
has_optimized_address_calculation
=
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
constexpr
auto
thread_cluster_desc
=
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
if
(
BlockSize
==
thread_cluster_desc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
{
mThreadwiseLoad
.
Run_optimized_src_address_calculation
(
p_block_src
,
p_thread_buffer
);
}
else
{
mThreadwiseLoad
.
Run
(
p_block_src
,
p_thread_buffer
);
}
}
else
if
(
get_thread_local_1d_id
()
<
thread_cluster_desc
.
GetElementSize
())
{
{
// TODO: threadwise copy is still being tweaked
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
if
(
has_optimized_address_calculation
)
...
@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
bool
has_optimized_address_calculation
=
constexpr
bool
has_optimized_address_calculation
=
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
constexpr
auto
thread_cluster_desc
=
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
if
(
BlockSize
==
thread_cluster_desc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
{
mThreadwiseStore
.
Run_optimized_dst_address_calculation
(
p_thread_buffer
,
p_block_dst
);
}
else
{
mThreadwiseStore
.
Run
(
p_thread_buffer
,
p_block_dst
);
}
}
else
if
(
get_thread_local_1d_id
()
<
thread_cluster_desc
.
GetElementSize
())
{
{
// TODO: threadwise copy is still being tweaked
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
if
(
has_optimized_address_calculation
)
...
@@ -158,27 +132,39 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -158,27 +132,39 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockSrcData
p_thread_buffer
[
GetThreadBufferSize
()];
BlockSrcData
p_thread_buffer
[
GetThreadBufferSize
()];
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
RunLoadThreadBuffer
(
p_block_src
,
p_thread_buffer
);
RunLoadThreadBuffer
(
p_block_src
,
p_thread_buffer
);
// if there is type conversion, it's done during store
// if there is type conversion, it's done during store
RunStoreThreadBuffer
(
p_thread_buffer
,
p_block_dst
);
RunStoreThreadBuffer
(
p_thread_buffer
,
p_block_dst
);
}
}
}
template
<
typename
T
,
bool
PositiveDirection
>
template
<
typename
T
,
bool
PositiveDirection
>
__device__
void
__device__
void
MoveSrcSliceWindow
(
const
T
&
step_sizes
,
MoveSrcSliceWindow
(
const
T
&
step_sizes
,
integral_constant
<
bool
,
PositiveDirection
>
positive_direction
)
integral_constant
<
bool
,
PositiveDirection
>
positive_direction
)
{
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
{
mThreadwiseLoad
.
MoveSrcSliceWindow
(
step_sizes
,
positive_direction
);
mThreadwiseLoad
.
MoveSrcSliceWindow
(
step_sizes
,
positive_direction
);
}
}
}
template
<
typename
T
,
bool
PositiveDirection
>
template
<
typename
T
,
bool
PositiveDirection
>
__device__
void
__device__
void
MoveDstSliceWindow
(
const
T
&
step_sizes
,
MoveDstSliceWindow
(
const
T
&
step_sizes
,
integral_constant
<
bool
,
PositiveDirection
>
positive_direction
)
integral_constant
<
bool
,
PositiveDirection
>
positive_direction
)
{
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
{
mThreadwiseStore
.
MoveDstSliceWindow
(
step_sizes
,
positive_direction
);
mThreadwiseStore
.
MoveDstSliceWindow
(
step_sizes
,
positive_direction
);
}
}
}
private:
private:
using
ThreadBufferDesc
=
decltype
(
make_native_tensor_descriptor_packed
(
ThreadSliceLengths
{}));
using
ThreadBufferDesc
=
decltype
(
make_native_tensor_descriptor_packed
(
ThreadSliceLengths
{}));
...
@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
DstAddressSpace
,
DstAddressSpace
,
DstInMemOp
>
;
DstInMemOp
>
;
static
constexpr
auto
mThreadClusterDesc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
ThreadwiseLoad
mThreadwiseLoad
;
ThreadwiseLoad
mThreadwiseLoad
;
ThreadwiseStore
mThreadwiseStore
;
ThreadwiseStore
mThreadwiseStore
;
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment