Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ef0c2e64
Commit
ef0c2e64
authored
Feb 02, 2021
by
Jing Zhang
Browse files
move thread_buff into blockcopy
parent
7cf350d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
24 deletions
+17
-24
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_v2.hpp
...nsor_operation/blockwise_generic_tensor_slice_copy_v2.hpp
+11
-10
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
...lude/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
+6
-14
No files found.
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_v2.hpp
View file @
ef0c2e64
...
@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
...
@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
return
ThreadBufferDesc
::
GetElementSpace
();
return
ThreadBufferDesc
::
GetElementSpace
();
}
}
template
<
typename
BlockSrcData
,
typename
ThreadBuffData
>
template
<
typename
BlockSrcData
>
__device__
void
RunLoadThreadBuffer
(
const
BlockSrcData
*
p_block_src
,
__device__
void
RunLoadThreadBuffer
(
const
BlockSrcData
*
p_block_src
)
ThreadBuffData
&
thread_buff
)
{
{
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
...
@@ -89,8 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
...
@@ -89,8 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
}
}
}
}
template
<
typename
ThreadBuffData
,
typename
BlockDstData
>
template
<
typename
BlockDstData
>
__device__
void
RunStoreThreadBuffer
(
ThreadBuffData
thread_buff
,
BlockDstData
*
p_block_dst
)
__device__
void
RunStoreThreadBuffer
(
BlockDstData
*
p_block_dst
)
{
{
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
...
@@ -99,9 +98,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
...
@@ -99,9 +98,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
}
}
}
}
template
<
typename
BlockSrcData
,
typename
BlockDstData
,
typename
ThreadBuffData
>
template
<
typename
BlockSrcData
,
typename
BlockDstData
>
__device__
void
__device__
void
Run
(
const
BlockSrcData
*
p_block_src
,
BlockDstData
*
p_block_dst
)
Run
(
const
BlockSrcData
*
p_block_src
,
BlockDstData
*
p_block_dst
,
ThreadBuffData
&
thread_buff
)
{
{
static_assert
(
ThreadBufferAddressSpace
==
AddressSpace
::
Vgpr
,
static_assert
(
ThreadBufferAddressSpace
==
AddressSpace
::
Vgpr
,
"wrong! This function use vgpr as its thread "
"wrong! This function use vgpr as its thread "
...
@@ -112,8 +110,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
...
@@ -112,8 +110,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
if
(
BlockSize
==
mThreadClusterDesc
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
get_thread_local_1d_id
()
<
mThreadClusterDesc
.
GetElementSize
())
{
{
RunLoadThreadBuffer
(
p_block_src
,
thread_buff
);
RunLoadThreadBuffer
(
p_block_src
);
RunStoreThreadBuffer
(
thread_buff
,
p_block_dst
);
RunStoreThreadBuffer
(
p_block_dst
);
}
}
}
}
...
@@ -163,6 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v5
...
@@ -163,6 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v5
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
ThreadwiseCopy
mThreadwiseCopy
;
ThreadwiseCopy
mThreadwiseCopy
;
using
ThreadBufferType
=
decltype
(
GetRegBuffer
<
float
,
GetThreadBufferSize
()
>
());
ThreadBufferType
thread_buff
;
};
};
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
View file @
ef0c2e64
...
@@ -496,18 +496,10 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
...
@@ -496,18 +496,10 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr
index_t
c_thread_size
=
MPerBlock
*
NPerBlock
/
BlockSize
;
constexpr
index_t
c_thread_size
=
MPerBlock
*
NPerBlock
/
BlockSize
;
auto
c_thread_vec
=
GetRegBuffer
<
AccFloat
,
c_thread_size
>
();
auto
c_thread_vec
=
GetRegBuffer
<
AccFloat
,
c_thread_size
>
();
using
ThreadBufferTypeA
=
decltype
(
GetRegBuffer
<
ABFloat
,
a_blockwise_copy
.
GetThreadBufferSize
()
>
());
using
ThreadBufferTypeB
=
decltype
(
GetRegBuffer
<
ABFloat
,
b_blockwise_copy
.
GetThreadBufferSize
()
>
());
ThreadBufferTypeA
thread_buff_a
;
ThreadBufferTypeB
thread_buff_b
;
// preload data into LDS
// preload data into LDS
{
{
a_blockwise_copy
.
Run
(
p_a_global
,
p_a_block
,
thread_buff_a
);
a_blockwise_copy
.
Run
(
p_a_global
,
p_a_block
);
b_blockwise_copy
.
Run
(
p_b_global
,
p_b_block
,
thread_buff_b
);
b_blockwise_copy
.
Run
(
p_b_global
,
p_b_block
);
}
}
constexpr
auto
blockwise_a_copy_src_step
=
Sequence
<
0
,
KPerBlock
,
0
,
0
>
{};
constexpr
auto
blockwise_a_copy_src_step
=
Sequence
<
0
,
KPerBlock
,
0
,
0
>
{};
...
@@ -521,8 +513,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
...
@@ -521,8 +513,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
a_blockwise_copy
.
MoveSrcSliceWindow
(
blockwise_a_copy_src_step
,
True
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
blockwise_a_copy_src_step
,
True
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
blockwise_b_copy_src_step
,
True
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
blockwise_b_copy_src_step
,
True
);
a_blockwise_copy
.
RunLoadThreadBuffer
(
p_a_global
,
thread_buff_a
);
a_blockwise_copy
.
RunLoadThreadBuffer
(
p_a_global
);
b_blockwise_copy
.
RunLoadThreadBuffer
(
p_b_global
,
thread_buff_b
);
b_blockwise_copy
.
RunLoadThreadBuffer
(
p_b_global
);
block_sync_lds
();
block_sync_lds
();
...
@@ -539,8 +531,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
...
@@ -539,8 +531,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
block_sync_lds
();
block_sync_lds
();
// store next data to LDS
// store next data to LDS
a_blockwise_copy
.
RunStoreThreadBuffer
(
thread_buff_a
,
p_a_block
);
a_blockwise_copy
.
RunStoreThreadBuffer
(
p_a_block
);
b_blockwise_copy
.
RunStoreThreadBuffer
(
thread_buff_b
,
p_b_block
);
b_blockwise_copy
.
RunStoreThreadBuffer
(
p_b_block
);
}
}
// tail
// tail
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment