Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
796b544e
Commit
796b544e
authored
Jul 06, 2023
by
ltqin
Browse files
remove share memory
parent
dcfe312b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
2 additions
and
21 deletions
+2
-21
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+1
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+1
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+0
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+0
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
...dwise_batched_multihead_attention_bacckward_ydotygrad.hpp
+0
-9
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
796b544e
...
...
@@ -51,7 +51,6 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
@@ -67,7 +66,6 @@ __global__ void
GridwiseGemm
::
Run
(
p_y_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_d_grid
+
d_batch_offset
,
p_shared
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d_grid_desc_m
,
block_2_ctile_map
);
...
...
@@ -759,8 +757,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
// TODO: distinguish A/B
// datatype
DDataType
,
DDataType
,
// datatype
YGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
796b544e
...
...
@@ -50,7 +50,6 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
@@ -66,7 +65,6 @@ __global__ void
GridwiseGemm
::
Run
(
p_y_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_d_grid
+
d_batch_offset
,
p_shared
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d_grid_desc_m
,
block_2_ctile_map
);
...
...
@@ -773,8 +771,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
// TODO: distinguish A/B
// datatype
DDataType
,
DDataType
,
// datatype
YGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
796b544e
...
...
@@ -37,7 +37,6 @@ __global__ void
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
cast_pointer_to_generic_address_space
(
group_kernel_args
));
...
...
@@ -73,7 +72,6 @@ __global__ void
GridwiseGemm
::
Run
(
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
+
d_batch_offset
,
static_cast
<
void
*>
(
p_shared
),
arg_ptr
[
group_id
].
d_y_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
d_grid_desc_m_
,
arg_ptr
[
group_id
].
d_block_2_ctile_map_
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
796b544e
...
...
@@ -37,7 +37,6 @@ __global__ void
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
cast_pointer_to_generic_address_space
(
group_kernel_args
));
...
...
@@ -73,7 +72,6 @@ __global__ void
GridwiseGemm
::
Run
(
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
+
d_batch_offset
,
static_cast
<
void
*>
(
p_shared
),
arg_ptr
[
group_id
].
d_y_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
d_grid_desc_m_
,
arg_ptr
[
group_id
].
d_block_2_ctile_map_
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
View file @
796b544e
...
...
@@ -136,15 +136,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
};
using
YDotYGrad_M_O
=
YDotYGrad_M_O_
<
BlockSize
,
MPerBlock
,
NPerBlock
>
;
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
return
MPerBlock
*
sizeof
(
FloatD
);
}
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_y_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
FloatD
*
__restrict__
p_d_grid
,
void
*
__restrict__
p_shared
,
const
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
y_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDesc_M
&
d_grid_desc_m
,
...
...
@@ -213,12 +207,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto
y_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_O
::
DstBufType
{};
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatD
*>
(
p_shared
),
MPerBlock
);
// clear accum buffers
y_dot_ygrad_thread_accum_buf
.
Clear
();
y_dot_ygrad_block_accum_buf
.
Clear
();
index_t
oblock_idx
=
0
;
do
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment