Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
31706d42
Commit
31706d42
authored
Sep 04, 2023
by
danyao12
Browse files
modify bias LDS addrs in bwd kernels
parent
70d700b3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
316 additions
and
266 deletions
+316
-266
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+67
-58
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+79
-62
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+80
-70
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+90
-76
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
31706d42
...
...
@@ -1191,56 +1191,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
KPack
>
;
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/ygrad/q/p_slash_sgrad sizes and offsets by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every tile size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto k_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto ygrad_block_desc_k0_m_k1 =
        GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto q_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
        GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto ygrad_block_space_size_aligned =
        math::integer_least_multiple(ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto q_block_space_size_aligned =
        math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
        p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // Start offsets: the tiles follow one another in the order k, ygrad, q, p_slash_sgrad.
    static constexpr auto k_block_space_offset     = 0;
    static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto q_block_space_offset =
        ygrad_block_space_offset + ygrad_block_space_size_aligned.value;
    static constexpr auto p_slash_sgrad_block_space_offset =
        q_block_space_offset + q_block_space_size_aligned.value;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the larger of the byte position just
// past the p_slash_sgrad tile and the byte size of the C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    // p_slash_sgrad is the last tile in the SharedMemTrait layout; offset and size are
    // GemmDataType element counts.
    const index_t gemm_tiles_end_in_bytes =
        (SharedMemTrait::p_slash_sgrad_block_space_offset +
         SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
        sizeof(GemmDataType);

    return math::max(gemm_tiles_end_in_bytes, c_shuffle_end_in_bytes);
}
// D0 tiling constants: D0M2 is fixed at 4 and D0M1 is MPerXdl divided by D0M2.
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...
...
@@ -1273,12 +1223,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// Type and size constants for a given D0 element type.
//   Type  - the element type itself.
//   Size0 - sizeof(DataType); GetSharedMemoryNumberOfByte() uses it to turn the D0 extent
//           into bytes.
//   Size  - sizeof(DataType); SharedMemTrait divides by it to compute d0_block_space_offset.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <typename DataType>
struct TypeTransform
{
    using Type                     = DataType;
    static constexpr index_t Size0 = sizeof(DataType);
    static constexpr index_t Size  = sizeof(DataType);
};
// Specialization for DataType == void (presumably "no D0 tensor" - TODO confirm).
//   Type  - ck::half_t stands in as the element type.
//   Size0 - 0, so the D0 byte extent computed in GetSharedMemoryNumberOfByte() is zero.
//   Size  - sizeof(ck::half_t); kept non-zero because SharedMemTrait divides by it.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <>
struct TypeTransform<void>
{
    using Type                     = ck::half_t;
    static constexpr index_t Size0 = 0;
    static constexpr index_t Size  = sizeof(ck::half_t);
};
// The thread-cluster length equals MPerXdl, which must not exceed KPerBlock.
static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock);
...
...
@@ -1354,6 +1308,66 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
2
>
;
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/ygrad/q/p_slash_sgrad sizes and offsets by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every tile size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto k_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto ygrad_block_desc_k0_m_k1 =
        GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto q_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
        GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto ygrad_block_space_size_aligned =
        math::integer_least_multiple(ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto q_block_space_size_aligned =
        math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
        p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // Start offsets: the tiles follow one another in the order k, ygrad, q, p_slash_sgrad.
    static constexpr auto k_block_space_offset     = 0;
    static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto q_block_space_offset =
        ygrad_block_space_offset + ygrad_block_space_size_aligned.value;
    static constexpr auto p_slash_sgrad_block_space_offset =
        q_block_space_offset + q_block_space_size_aligned.value;

    // D0 tile. Its size is rounded with the same max_lds_align as the other tiles.
    // NOTE(review): max_lds_align is derived from sizeof(GemmDataType), not from the D0
    // element size - confirm this is the intended granule.
    static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
        D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
    // Same start as p_slash_sgrad, converted from GemmDataType elements into units of
    // TypeTransform<D0DataType>::Size bytes.
    static constexpr auto d0_block_space_offset =
        p_slash_sgrad_block_space_offset * sizeof(GemmDataType) /
        D0Loader::template TypeTransform<D0DataType>::Size;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest of the byte positions just
// past the p_slash_sgrad tile and the D0 tile, and the byte size of the C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    // D0 extent scaled by TypeTransform<D0DataType>::Size0, which the void specialization
    // sets to 0.
    const index_t d0_end_in_bytes =
        (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
        D0Loader::template TypeTransform<D0DataType>::Size0;

    // p_slash_sgrad offset and size are GemmDataType element counts.
    const index_t gemm_tiles_end_in_bytes =
        (SharedMemTrait::p_slash_sgrad_block_space_offset +
         SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
        sizeof(GemmDataType);

    return math::max(gemm_tiles_end_in_bytes, d0_end_in_bytes, c_shuffle_end_in_bytes);
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
...
...
@@ -1987,8 +2001,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
p_slash_sgrad_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
...
...
@@ -2023,10 +2036,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
});
});
// load k
gemm_tile_k_blockwise_copy
.
RunWrite
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
,
k_block_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
31706d42
...
...
@@ -1276,65 +1276,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/a/b1/a2 sizes by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for K
    static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto k_block_space_offset = 0;

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto b1_block_desc_bk0_n_bk1 =
        GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto a_block_space_size_aligned =
        math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto b1_block_space_size_aligned =
        math::integer_least_multiple(b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto a2_block_space_size_aligned =
        math::integer_least_multiple(a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // a, b1 and a2 all start right after the K tile, i.e. they share one start offset.
    static constexpr auto a_block_space_offset  = k_block_space_size_aligned.value;
    static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;

    // LDS allocation for reduction
    static constexpr index_t reduction_space_size_aligned =
        math::integer_least_multiple(BlockSize, max_lds_align);

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest byte extent among the four
// gemm stages (K tile plus that stage's second tile) and the C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // K tile + a tile.
    const index_t gemm0_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + b1 tile.
    const index_t gemm1_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::b1_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a2 tile.
    const index_t gemm2_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a2_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a tile (same formula as gemm0).
    const index_t gemm3_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);

    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    return math::max(gemm0_end_in_bytes,
                     gemm1_end_in_bytes,
                     gemm2_end_in_bytes,
                     gemm3_end_in_bytes,
                     c_shuffle_end_in_bytes);
}
// D0 tiling constants: D0M2 is fixed at 4 and D0M1 is MPerXdl divided by D0M2.
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...
...
@@ -1367,12 +1308,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// Type and size constants for a given D0 element type.
//   Type  - the element type itself.
//   Size0 - sizeof(DataType); GetSharedMemoryNumberOfByte() uses it to turn the D0 extent
//           into bytes.
//   Size  - sizeof(DataType); SharedMemTrait divides by it to compute d0_block_space_offset.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <typename DataType>
struct TypeTransform
{
    using Type                     = DataType;
    static constexpr index_t Size0 = sizeof(DataType);
    static constexpr index_t Size  = sizeof(DataType);
};
// Specialization for DataType == void (presumably "no D0 tensor" - TODO confirm).
//   Type  - ck::half_t stands in as the element type.
//   Size0 - 0, so the D0 byte extent computed in GetSharedMemoryNumberOfByte() is zero.
//   Size  - sizeof(ck::half_t); kept non-zero because SharedMemTrait divides by it.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <>
struct TypeTransform<void>
{
    using Type                     = ck::half_t;
    static constexpr index_t Size0 = 0;
    static constexpr index_t Size  = sizeof(ck::half_t);
};
// The thread-cluster length is hard-coded to 32, and NPerXdl is required to be 32 as well.
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
...
...
@@ -1448,6 +1393,78 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
2
>
;
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/a/b1/a2 sizes by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for K
    static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto k_block_space_offset = 0;

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto b1_block_desc_bk0_n_bk1 =
        GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto a_block_space_size_aligned =
        math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto b1_block_space_size_aligned =
        math::integer_least_multiple(b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto a2_block_space_size_aligned =
        math::integer_least_multiple(a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // a, b1 and a2 all start right after the K tile, i.e. they share one start offset.
    static constexpr auto a_block_space_offset  = k_block_space_size_aligned.value;
    static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;

    // LDS allocation for reduction
    static constexpr index_t reduction_space_size_aligned =
        math::integer_least_multiple(BlockSize, max_lds_align);

    // D0 tile. Its size is rounded with the same max_lds_align as the other tiles.
    // NOTE(review): max_lds_align is derived from sizeof(GemmDataType), not from the D0
    // element size - confirm this is the intended granule.
    static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
        D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
    // Starts right after the K tile: the K size converted from GemmDataType elements into
    // units of TypeTransform<D0DataType>::Size bytes.
    static constexpr auto d0_block_space_offset =
        k_block_space_size_aligned.value * sizeof(GemmDataType) /
        D0Loader::template TypeTransform<D0DataType>::Size;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest byte extent among the four
// gemm stages (K tile plus that stage's second tile), the D0 tile and the C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // K tile + a tile.
    const index_t gemm0_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + b1 tile.
    const index_t gemm1_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::b1_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a2 tile.
    const index_t gemm2_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a2_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a tile (same formula as gemm0).
    const index_t gemm3_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);

    // D0 extent scaled by TypeTransform<D0DataType>::Size0, which the void specialization
    // sets to 0.
    const index_t d0_end_in_bytes =
        (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
        D0Loader::template TypeTransform<D0DataType>::Size0;

    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    return math::max(gemm0_end_in_bytes,
                     gemm1_end_in_bytes,
                     gemm2_end_in_bytes,
                     gemm3_end_in_bytes,
                     d0_end_in_bytes,
                     c_shuffle_end_in_bytes);
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
...
...
@@ -2137,7 +2154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a
_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0
_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
31706d42
...
...
@@ -1259,68 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
// Instantiates YDotYGrad_M_O_ with this kernel's BlockSize, MPerBlock and Gemm1NPerBlock.
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/ygrad/q/p_slash_sgrad sizes and offsets by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto k_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto ygrad_block_desc_k0_m_k1 =
        GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto q_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
        GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto ygrad_block_space_size_aligned =
        math::integer_least_multiple(ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto q_block_space_size_aligned =
        math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
        p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // Start offsets: the tiles follow one another in the order k, ygrad, q, p_slash_sgrad.
    static constexpr auto k_block_space_offset     = 0;
    static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto q_block_space_offset =
        ygrad_block_space_offset + ygrad_block_space_size_aligned.value;
    static constexpr auto p_slash_sgrad_block_space_offset =
        q_block_space_offset + q_block_space_size_aligned.value;

    // LDS allocation for reduction
    static constexpr index_t reduction_space_size_aligned =
        math::integer_least_multiple(BlockSize, max_lds_align);
    // Same start as p_slash_sgrad, converted from GemmDataType elements to FloatGemmAcc elements.
    static constexpr auto reduction_space_offset =
        p_slash_sgrad_block_space_offset * sizeof(GemmDataType) / sizeof(FloatGemmAcc);

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest of the byte positions just
// past the p_slash_sgrad tile and the reduction workspace, and the byte size of the
// C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    // The reduction workspace is addressed in FloatGemmAcc elements.
    const index_t reduction_end_in_bytes =
        (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
        sizeof(FloatGemmAcc);

    // p_slash_sgrad offset and size are GemmDataType element counts.
    const index_t gemm_tiles_end_in_bytes =
        (SharedMemTrait::p_slash_sgrad_block_space_offset +
         SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
        sizeof(GemmDataType);

    return math::max(gemm_tiles_end_in_bytes, reduction_end_in_bytes, c_shuffle_end_in_bytes);
}
// D0 tiling constants: D0M2 is fixed at 4 and D0M1 is MPerXdl divided by D0M2.
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...
...
@@ -1353,12 +1291,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Type and size constants for a given D0 element type.
//   Type  - the element type itself.
//   Size0 - sizeof(DataType); GetSharedMemoryNumberOfByte() uses it to turn the D0 extent
//           into bytes.
//   Size  - sizeof(DataType); SharedMemTrait divides by it to compute d0_block_space_offset.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <typename DataType>
struct TypeTransform
{
    using Type                     = DataType;
    static constexpr index_t Size0 = sizeof(DataType);
    static constexpr index_t Size  = sizeof(DataType);
};
// Specialization for DataType == void (presumably "no D0 tensor" - TODO confirm).
//   Type  - ck::half_t stands in as the element type.
//   Size0 - 0, so the D0 byte extent computed in GetSharedMemoryNumberOfByte() is zero.
//   Size  - sizeof(ck::half_t); kept non-zero because SharedMemTrait divides by it.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <>
struct TypeTransform<void>
{
    using Type                     = ck::half_t;
    static constexpr index_t Size0 = 0;
    static constexpr index_t Size  = sizeof(ck::half_t);
};
// The thread-cluster length equals MPerXdl, which must not exceed KPerBlock.
static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock);
...
...
@@ -1434,6 +1376,79 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
2
>
;
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/ygrad/q/p_slash_sgrad sizes and offsets by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto k_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto ygrad_block_desc_k0_m_k1 =
        GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto q_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
        GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto ygrad_block_space_size_aligned =
        math::integer_least_multiple(ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto q_block_space_size_aligned =
        math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
        p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // Start offsets: the tiles follow one another in the order k, ygrad, q, p_slash_sgrad.
    static constexpr auto k_block_space_offset     = 0;
    static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto q_block_space_offset =
        ygrad_block_space_offset + ygrad_block_space_size_aligned.value;
    static constexpr auto p_slash_sgrad_block_space_offset =
        q_block_space_offset + q_block_space_size_aligned.value;

    // LDS allocation for reduction
    static constexpr index_t reduction_space_size_aligned =
        math::integer_least_multiple(BlockSize, max_lds_align);
    // Same start as p_slash_sgrad, converted from GemmDataType elements to FloatGemmAcc elements.
    static constexpr auto reduction_space_offset =
        p_slash_sgrad_block_space_offset * sizeof(GemmDataType) / sizeof(FloatGemmAcc);

    // D0 tile. Its size is rounded with the same max_lds_align as the other tiles.
    // NOTE(review): max_lds_align is derived from sizeof(GemmDataType), not from the D0
    // element size - confirm this is the intended granule.
    static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
        D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
    // Same start as p_slash_sgrad, converted from GemmDataType elements into units of
    // TypeTransform<D0DataType>::Size bytes.
    static constexpr auto d0_block_space_offset =
        p_slash_sgrad_block_space_offset * sizeof(GemmDataType) /
        D0Loader::template TypeTransform<D0DataType>::Size;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest of the byte positions just
// past the p_slash_sgrad tile, the reduction workspace and the D0 tile, and the byte size
// of the C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    // D0 extent scaled by TypeTransform<D0DataType>::Size0, which the void specialization
    // sets to 0.
    const index_t d0_end_in_bytes =
        (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
        D0Loader::template TypeTransform<D0DataType>::Size0;

    // The reduction workspace is addressed in FloatGemmAcc elements.
    const index_t reduction_end_in_bytes =
        (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
        sizeof(FloatGemmAcc);

    // p_slash_sgrad offset and size are GemmDataType element counts.
    const index_t gemm_tiles_end_in_bytes =
        (SharedMemTrait::p_slash_sgrad_block_space_offset +
         SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
        sizeof(GemmDataType);

    return math::max(
        gemm_tiles_end_in_bytes, reduction_end_in_bytes, d0_end_in_bytes, c_shuffle_end_in_bytes);
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
...
...
@@ -2186,8 +2201,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
p_slash_sgrad_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
...
...
@@ -2222,10 +2236,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
});
});
// load k
gemm_tile_k_blockwise_copy
.
RunWrite
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
,
k_block_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
31706d42
...
...
@@ -1321,79 +1321,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/a/b1/a2 sizes by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for K
    static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto k_block_space_offset = 0;

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto b1_block_desc_bk0_n_bk1 =
        GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto a_block_space_size_aligned =
        math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto b1_block_space_size_aligned =
        math::integer_least_multiple(b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto a2_block_space_size_aligned =
        math::integer_least_multiple(a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // a, b1 and a2 all start right after the K tile, i.e. they share one start offset.
    static constexpr auto a_block_space_offset  = k_block_space_size_aligned.value;
    static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;

    // LDS allocation for reduction
    static constexpr index_t reduction_space_size_aligned =
        math::integer_least_multiple(BlockSize, max_lds_align);
    // Placed after the K tile plus the largest of the a/b1/a2 tiles; the element offset is
    // converted from GemmDataType units to FloatGemmAcc units.
    static constexpr auto reduction_space_offset =
        (math::max(a_block_space_size_aligned.value,
                   b1_block_space_size_aligned.value,
                   a2_block_space_size_aligned.value) +
         k_block_space_size_aligned.value) *
        sizeof(GemmDataType) / sizeof(FloatGemmAcc);

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest byte extent among the four
// gemm stages (K tile plus that stage's second tile), the reduction workspace and the
// C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // K tile + a tile.
    const index_t gemm0_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + b1 tile.
    const index_t gemm1_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::b1_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a2 tile.
    const index_t gemm2_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a2_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a tile (same formula as gemm0).
    const index_t gemm3_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);

    // The reduction workspace is addressed in FloatGemmAcc elements.
    const index_t reduction_end_in_bytes =
        (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
        sizeof(FloatGemmAcc);

    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    return math::max(gemm0_end_in_bytes,
                     gemm1_end_in_bytes,
                     gemm2_end_in_bytes,
                     gemm3_end_in_bytes,
                     reduction_end_in_bytes,
                     c_shuffle_end_in_bytes);
}
// D0 tiling constants: D0M2 is fixed at 4 and D0M1 is MPerXdl divided by D0M2.
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...
...
@@ -1426,12 +1353,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Type and size constants for a given D0 element type.
//   Type  - the element type itself.
//   Size0 - sizeof(DataType); GetSharedMemoryNumberOfByte() uses it to turn the D0 extent
//           into bytes.
//   Size  - sizeof(DataType); SharedMemTrait divides by it to compute d0_block_space_offset.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <typename DataType>
struct TypeTransform
{
    using Type                     = DataType;
    static constexpr index_t Size0 = sizeof(DataType);
    static constexpr index_t Size  = sizeof(DataType);
};
// Specialization for DataType == void (presumably "no D0 tensor" - TODO confirm).
//   Type  - ck::half_t stands in as the element type.
//   Size0 - 0, so the D0 byte extent computed in GetSharedMemoryNumberOfByte() is zero.
//   Size  - sizeof(ck::half_t); kept non-zero because SharedMemTrait divides by it.
// The alias is declared exactly once: repeating a member alias inside a class is ill-formed.
template <>
struct TypeTransform<void>
{
    using Type                     = ck::half_t;
    static constexpr index_t Size0 = 0;
    static constexpr index_t Size  = sizeof(ck::half_t);
};
// The thread-cluster length is hard-coded to 32, and NPerXdl is required to be 32 as well.
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
...
...
@@ -1507,6 +1438,89 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2
>
;
};
// Compile-time description of how this kernel places its tiles in LDS: block descriptors,
// aligned tile sizes and tile start offsets. GetSharedMemoryNumberOfByte() multiplies the
// k/a/b1/a2 sizes by sizeof(GemmDataType), so they are element counts.
struct SharedMemTrait
{
    // Alignment granule applied to every size below: 16 / sizeof(GemmDataType).
    static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

    // LDS allocation for K
    static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
    static constexpr auto k_block_space_size_aligned =
        math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto k_block_space_offset = 0;

    // LDS allocation for A and B: be careful of alignment
    static constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
    static constexpr auto b1_block_desc_bk0_n_bk1 =
        GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
    static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();

    // Element-space size of each tile, rounded up to a multiple of max_lds_align.
    static constexpr auto a_block_space_size_aligned =
        math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto b1_block_space_size_aligned =
        math::integer_least_multiple(b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
    static constexpr auto a2_block_space_size_aligned =
        math::integer_least_multiple(a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // a, b1 and a2 all start right after the K tile, i.e. they share one start offset.
    static constexpr auto a_block_space_offset  = k_block_space_size_aligned.value;
    static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
    static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;

    // LDS allocation for reduction
    static constexpr index_t reduction_space_size_aligned =
        math::integer_least_multiple(BlockSize, max_lds_align);
    // Placed after the K tile plus the largest of the a/b1/a2 tiles; the element offset is
    // converted from GemmDataType units to FloatGemmAcc units.
    static constexpr auto reduction_space_offset =
        (math::max(a_block_space_size_aligned.value,
                   b1_block_space_size_aligned.value,
                   a2_block_space_size_aligned.value) +
         k_block_space_size_aligned.value) *
        sizeof(GemmDataType) / sizeof(FloatGemmAcc);

    // D0 tile. Its size is rounded with the same max_lds_align as the other tiles.
    // NOTE(review): max_lds_align is derived from sizeof(GemmDataType), not from the D0
    // element size - confirm this is the intended granule.
    static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
        D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
    // Starts right after the K tile: the K size converted from GemmDataType elements into
    // units of TypeTransform<D0DataType>::Size bytes.
    static constexpr auto d0_block_space_offset =
        k_block_space_size_aligned.value * sizeof(GemmDataType) /
        D0Loader::template TypeTransform<D0DataType>::Size;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    static constexpr auto c_block_space_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
// Returns the number of LDS bytes the kernel needs: the largest byte extent among the four
// gemm stages (K tile plus that stage's second tile), the reduction workspace, the D0 tile
// and the C-shuffle tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // K tile + a tile.
    const index_t gemm0_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + b1 tile.
    const index_t gemm1_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::b1_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a2 tile.
    const index_t gemm2_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a2_block_space_size_aligned) *
                                       sizeof(GemmDataType);
    // K tile + a tile (same formula as gemm0).
    const index_t gemm3_end_in_bytes = (SharedMemTrait::k_block_space_size_aligned +
                                        SharedMemTrait::a_block_space_size_aligned) *
                                       sizeof(GemmDataType);

    // The reduction workspace is addressed in FloatGemmAcc elements.
    const index_t reduction_end_in_bytes =
        (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
        sizeof(FloatGemmAcc);

    // D0 extent scaled by TypeTransform<D0DataType>::Size0, which the void specialization
    // sets to 0.
    const index_t d0_end_in_bytes =
        (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
        D0Loader::template TypeTransform<D0DataType>::Size0;

    // The C-shuffle tile is sized in FloatCShuffle elements.
    const index_t c_shuffle_end_in_bytes =
        SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

    return math::max(gemm0_end_in_bytes,
                     gemm1_end_in_bytes,
                     gemm2_end_in_bytes,
                     gemm3_end_in_bytes,
                     reduction_end_in_bytes,
                     d0_end_in_bytes,
                     c_shuffle_end_in_bytes);
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
...
...
@@ -2292,7 +2306,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a
_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0
_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment