Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
940f786e
Commit
940f786e
authored
Nov 21, 2023
by
letaoqin
Browse files
fix d0 load parameter
parent
9de99bdb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
15 deletions
+40
-15
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
+30
-11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1r1.hpp
...tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1r1.hpp
+10
-4
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
View file @
940f786e
...
...
@@ -239,7 +239,12 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
SharedMemTrait
::
reduction_space_size_aligned
*
sizeof
(
FloatGemmAcc
);
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
const
index_t
d0_bytes_end
=
SharedMemTrait
::
d0_block_space_offset
*
sizeof
(
FloatAB
)
+
SharedMemTrait
::
d0_block_space_size_aligned
*
D0Operator
::
template
TypeTransform
<
D0DataType
>
::
Size0
;
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
,
d0_bytes_end
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
...
@@ -362,14 +367,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
C1GridDesc_M_N
{}))
>
;
static
constexpr
auto
AKPerBlock
=
32
;
static
constexpr
auto
D0N2
=
AK1
;
static
constexpr
auto
D0N1
=
Number
<
32
/
AK1
.
value
>
{};
static
constexpr
auto
D0N0
=
Number
<
NPerBlock
/
32
>
{};
static
constexpr
auto
D0N0_PerShuffle
=
Number
<
KPerBlock
/
32
>
{};
static
constexpr
auto
D0_NumShuffle
=
NPerBlock
/
KPerBlock
;
static
constexpr
auto
D0N0_PerShuffle
=
Number
<
A
KPerBlock
/
32
>
{};
static
constexpr
auto
D0_NumShuffle
=
NPerBlock
/
A
KPerBlock
;
static_assert
(
NPerBlock
%
KPerBlock
==
0
&&
KPerBlock
%
32
==
0
,
"KPerBlock should be multiple of 32 and divisor of NPerBlock"
);
static_assert
(
NPerBlock
%
A
KPerBlock
==
0
&&
A
KPerBlock
%
32
==
0
,
"
A
KPerBlock should be multiple of 32 and divisor of NPerBlock"
);
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
...
...
@@ -403,11 +409,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct
TypeTransform
{
using
Type
=
DataType
;
static
constexpr
index_t
Size0
=
sizeof
(
DataType
);
static
constexpr
index_t
Size
=
sizeof
(
DataType
);
};
template
<
>
struct
TypeTransform
<
void
>
{
using
Type
=
ck
::
half_t
;
static
constexpr
index_t
Size0
=
0
;
static
constexpr
index_t
Size
=
sizeof
(
ck
::
half_t
);
};
__host__
__device__
static
constexpr
auto
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3
()
...
...
@@ -499,7 +509,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static
constexpr
auto
max_lds_align
=
math
::
lcm
(
math
::
lcm
(
AK1
,
BK1
),
B1K1
);
static
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
()
*
Number
<
Q_d
/
KPerBlock
>
{},
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
()
*
Number
<
std
::
max
(
Q_d
/
KPerBlock
,
1
)
>
{},
max_lds_align
);
static
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
...
...
@@ -521,6 +531,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
static
constexpr
auto
c_block_space_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
// LDS allocation for D0 shuffle in LDS
static
constexpr
auto
d0_block_space_offset
=
a_block_space_size_aligned
.
value
;
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Operator
::
d0_block_dst_desc_m0_n0_n1_n2_m1_n3
.
GetElementSpaceSize
(),
max_lds_align
);
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
,
typename
C0MatrixMask
>
...
...
@@ -686,7 +701,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const
auto
b_block_reset_copy_step
=
make_multi_index
(
-
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
),
NPerBlock
,
0
);
const
auto
Q
_k
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
q
_k
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v1r1
<
1
>
{};
...
...
@@ -937,6 +952,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
continue
;
}
// gemm0
const
bool
is_can_load_once
=
(
q_k
<=
64
&&
KPerBlock
<=
64
);
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
,
FloatAB
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
...
...
@@ -953,7 +969,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
acc_thread_buf
,
num_k_block_main_loop
,
p_shared
,
gemm1_k_block_outer_index
==
0
||
Q_k
>
64
);
(
gemm1_k_block_outer_index
==
0
&&
is_can_load_once
)
||
(
!
is_can_load_once
),
KPerBlock
==
32
&&
q_k
==
64
);
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
...
...
@@ -1029,7 +1046,9 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
p_d0_grid
,
d0_grid_desc_m0_n0_n1_n2_m1_n3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
*
sizeof
(
FloatAB
)
/
sizeof
(
D0DataType
),
D0Operator
::
d0_block_dst_desc_m0_n0_n1_n2_m1_n3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1r1.hpp
View file @
940f786e
...
...
@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1r1<1>
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
,
void
*
p_shared
,
bool
bIsLoadAblock
)
bool
bIsLoadAblock
,
bool
bIsChangeCache
)
{
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
0
,
a_block_desc
.
GetElementSpaceSize
());
...
...
@@ -95,9 +96,14 @@ struct GridwiseGemmPipeline_v1r1<1>
if
(
bIsLoadAblock
)
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
ignore
=
bIsChangeCache
;
if
(
bIsChangeCache
)
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_desc
.
GetElementSpaceSize
()
*
(
i
+
1
),
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_desc
.
GetElementSpaceSize
()
*
(
i
+
1
),
a_block_desc
.
GetElementSpaceSize
());
if
(
bIsLoadAblock
)
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment