Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9e527364
Commit
9e527364
authored
Sep 08, 2023
by
danyao12
Browse files
modify variable name
parent
b7b7e153
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
41 deletions
+41
-41
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+12
-12
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+15
-15
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
9e527364
...
...
@@ -262,29 +262,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
9e527364
...
...
@@ -113,11 +113,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
...
...
@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
...
...
@@ -307,29 +307,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
9e527364
...
...
@@ -261,29 +261,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
9e527364
...
...
@@ -112,11 +112,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
...
...
@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
...
...
@@ -306,29 +306,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -1478,9 +1478,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
D0Loader
::
template
TypeTransform
<
D0DataType
>
::
Size
;
static
constexpr
auto
d0_block_space_offset
=
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
D0Loader
::
template
TypeTransform
<
D0DataType
>
::
Size
;
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment