Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
af695bee
Commit
af695bee
authored
Sep 08, 2023
by
danyao12
Browse files
Merge branch 'mha-train-develop-bwdopt-bias' into mha-train-develop-dropout8bit
parents
8ced5c4f
9e527364
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
38 deletions
+38
-38
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+12
-12
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+12
-12
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
af695bee
...
@@ -262,29 +262,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -262,29 +262,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
return
false
;
}
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
{
return
false
;
return
false
;
}
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
af695bee
...
@@ -113,11 +113,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -113,11 +113,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
WaveSize
=
64
;
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// K1 should be Number<...>
// Gemm0
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
...
@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
...
@@ -307,29 +307,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -307,29 +307,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
return
false
;
}
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
{
return
false
;
return
false
;
}
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
af695bee
...
@@ -261,29 +261,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -261,29 +261,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
return
false
;
}
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
{
return
false
;
return
false
;
}
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
af695bee
...
@@ -112,11 +112,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -112,11 +112,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
WaveSize
=
64
;
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// K1 should be Number<...>
// Gemm0
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
...
@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
K_K0
=
Number
<
Gemm1NPerBlock
/
BK1Value
>
{};
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K3
=
BK1
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K2
=
mfma
.
num_input_blks
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
static
constexpr
auto
V_K1
=
KPerBlock
/
V_K2
/
V_K3
;
...
@@ -306,29 +306,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -306,29 +306,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
return
false
;
}
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
{
return
false
;
return
false
;
}
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
{
return
false
;
return
false
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment