Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bd22abb5
"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "c882ac61182d423a89d21b453251a20fb7271a67"
Commit
bd22abb5
authored
Apr 18, 2020
by
Chao Liu
Browse files
refactor
parent
e131f6aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
48 deletions
+37
-48
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
.../tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+37
-48
No files found.
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @
bd22abb5
...
...
@@ -52,20 +52,23 @@ struct BlockwiseGenericTensorSliceCopy_v4
is_same
<
BlockSliceLengths
,
decltype
(
ThreadSliceLengths
{}
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
// map threads to cluster
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
static_assert
(
BlockSize
>=
mThreadClusterDesc
.
GetElementSize
(),
"wrong! BlockSize too small"
);
const
auto
thread_cluster_id
=
thread_cluster_desc
.
CalculateClusterIndex
(
get_thread_local_1d_id
());
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
const
auto
thread_cluster_id
=
mThreadClusterDesc
.
CalculateClusterIndex
(
get_thread_local_1d_id
());
const
auto
thread_data_id_begin
=
thread_cluster_id
*
ThreadSliceLengths
{};
const
auto
thread_data_id_begin
=
thread_cluster_id
*
ThreadSliceLengths
{};
mThreadwiseLoad
.
SetSrcSliceOrigin
(
src_block_slice_origin
+
thread_data_id_begin
);
mThreadwiseLoad
.
SetDstSliceOrigin
(
make_zero_array
<
index_t
,
nDim
>
());
mThreadwiseLoad
.
SetSrcSliceOrigin
(
src_block_slice_origin
+
thread_data_id_begin
);
mThreadwiseLoad
.
SetDstSliceOrigin
(
make_zero_array
<
index_t
,
nDim
>
());
mThreadwiseStore
.
SetSrcSliceOrigin
(
make_zero_array
<
index_t
,
nDim
>
());
mThreadwiseStore
.
SetDstSliceOrigin
(
dst_block_slice_origin
+
thread_data_id_begin
);
mThreadwiseStore
.
SetSrcSliceOrigin
(
make_zero_array
<
index_t
,
nDim
>
());
mThreadwiseStore
.
SetDstSliceOrigin
(
dst_block_slice_origin
+
thread_data_id_begin
);
}
}
__device__
static
constexpr
index_t
GetThreadBufferSize
()
...
...
@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
bool
has_optimized_address_calculation
=
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
if
(
BlockSize
==
thread_cluster_desc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
{
mThreadwiseLoad
.
Run_optimized_src_address_calculation
(
p_block_src
,
p_thread_buffer
);
}
else
{
mThreadwiseLoad
.
Run
(
p_block_src
,
p_thread_buffer
);
}
}
else
if
(
get_thread_local_1d_id
()
<
thread_cluster_desc
.
GetElementSize
())
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
...
...
@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
bool
has_optimized_address_calculation
=
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
if
(
BlockSize
==
thread_cluster_desc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
{
mThreadwiseStore
.
Run_optimized_dst_address_calculation
(
p_thread_buffer
,
p_block_dst
);
}
else
{
mThreadwiseStore
.
Run
(
p_thread_buffer
,
p_block_dst
);
}
}
else
if
(
get_thread_local_1d_id
()
<
thread_cluster_desc
.
GetElementSize
())
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
...
...
@@ -158,10 +132,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockSrcData
p_thread_buffer
[
GetThreadBufferSize
()];
RunLoadThreadBuffer
(
p_block_src
,
p_thread_buffer
);
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
RunLoadThreadBuffer
(
p_block_src
,
p_thread_buffer
);
// if there is type conversion, it's done during store
RunStoreThreadBuffer
(
p_thread_buffer
,
p_block_dst
);
// if there is type conversion, it's done during store
RunStoreThreadBuffer
(
p_thread_buffer
,
p_block_dst
);
}
}
template
<
typename
T
,
bool
PositiveDirection
>
...
...
@@ -169,7 +147,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
MoveSrcSliceWindow
(
const
T
&
step_sizes
,
integral_constant
<
bool
,
PositiveDirection
>
positive_direction
)
{
mThreadwiseLoad
.
MoveSrcSliceWindow
(
step_sizes
,
positive_direction
);
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
mThreadwiseLoad
.
MoveSrcSliceWindow
(
step_sizes
,
positive_direction
);
}
}
template
<
typename
T
,
bool
PositiveDirection
>
...
...
@@ -177,7 +159,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
MoveDstSliceWindow
(
const
T
&
step_sizes
,
integral_constant
<
bool
,
PositiveDirection
>
positive_direction
)
{
mThreadwiseStore
.
MoveDstSliceWindow
(
step_sizes
,
positive_direction
);
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
mThreadwiseStore
.
MoveDstSliceWindow
(
step_sizes
,
positive_direction
);
}
}
private:
...
...
@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
DstAddressSpace
,
DstInMemOp
>
;
static
constexpr
auto
mThreadClusterDesc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
ThreadwiseLoad
mThreadwiseLoad
;
ThreadwiseStore
mThreadwiseStore
;
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment