Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bf44991f
Commit
bf44991f
authored
May 31, 2022
by
Anthony Chang
Browse files
use LDS mem pool for reduction workspace
parent
3db406f0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
8 deletions
+13
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+13
-8
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
bf44991f
...
@@ -217,12 +217,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -217,12 +217,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_size
=
// Align 16 bytes (maximum LDS read/write width)
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
constexpr
auto
c_block_size_aligned
=
math
::
integer_least_multiple
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
()
*
sizeof
(
FloatCShuffle
),
16
)
/
sizeof
(
FloatCShuffle
);
// LDS allocation for reduction workspace
constexpr
index_t
c_lds_workspace_size
=
BlockSize
;
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatCShuffle
));
c_block_size
_aligned
*
sizeof
(
FloatCShuffle
)
+
c_lds_workspace_size
*
sizeof
(
FloatReduceAcc
)
);
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
@@ -734,11 +738,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -734,11 +738,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto
c0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatC0
>
(
auto
c0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatC0
>
(
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSpaceSize
());
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSpaceSize
());
// TODO ANT: incorporate in singly defined p_shared. calculate proper total size in
// Align 16 bytes (maximum LDS read/write width)
// GetSharedMemoryNumberOfByte() and shift pointer as approriate
constexpr
auto
c_block_size_aligned
=
math
::
integer_least_multiple
(
__shared__
FloatReduceAcc
p_d_reduce_work_buffer
[
BlockSize
];
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
()
*
sizeof
(
FloatCShuffle
),
16
)
/
sizeof
(
FloatCShuffle
);
auto
d_reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_d_reduce_work_buffer
,
BlockSize
);
auto
d_reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
reinterpret_cast
<
FloatReduceAcc
*>
(
static_cast
<
FloatCShuffle
*>
(
p_shared
)
+
c_block_size_aligned
),
BlockSize
);
// Sum thread workspace
// Sum thread workspace
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment