Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
774c9209
"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "418125d0bc338f025689181b293b9585f1f4a17b"
Commit
774c9209
authored
Sep 13, 2023
by
letaoqin
Browse files
change make c grid desc to c0 grid desc, because c is mxo, c0 is mxn
parent
522d8b2f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
87 additions
and
51 deletions
+87
-51
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+10
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+7
-7
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
+9
-0
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+23
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
774c9209
...
@@ -591,7 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -591,7 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
//
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
@@ -647,7 +647,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -647,7 +647,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
...
@@ -677,7 +677,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -677,7 +677,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
@@ -685,7 +685,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -685,7 +685,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
...
@@ -936,7 +936,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -936,7 +936,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
...
@@ -975,7 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -975,7 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
d0_grid_desc_g_m_n_
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
774c9209
...
@@ -599,14 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -599,14 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
}
// Z in Gemm0 C position
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
//
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
@@ -686,7 +686,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -686,7 +686,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
@@ -694,7 +694,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -694,7 +694,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
...
@@ -951,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -951,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
...
@@ -990,7 +990,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -990,7 +990,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
d0_grid_desc_g_m_n_
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
774c9209
...
@@ -526,7 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -526,7 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
//
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
@@ -582,12 +582,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -582,12 +582,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
@@ -595,7 +595,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -595,7 +595,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
...
@@ -829,7 +829,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -829,7 +829,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
...
@@ -875,7 +875,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -875,7 +875,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
d0_grid_desc_g_m_n_
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
774c9209
...
@@ -534,14 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -534,14 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
}
// Z in Gemm0 C position
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
//
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
@@ -596,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -596,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
@@ -604,7 +604,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -604,7 +604,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
...
@@ -844,7 +844,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -844,7 +844,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
...
@@ -891,7 +891,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -891,7 +891,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
d0_grid_desc_g_m_n_
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
...
@@ -941,6 +941,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -941,6 +941,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_g_m_n_: "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
}
}
// pointers
// pointers
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
774c9209
...
@@ -545,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -545,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
@@ -577,8 +577,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -577,8 +577,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
static
auto
...
@@ -586,8 +586,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -586,8 +586,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
...
@@ -998,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -998,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
774c9209
...
@@ -608,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -608,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
@@ -640,8 +640,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -640,8 +640,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
static
auto
...
@@ -649,8 +649,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -649,8 +649,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
...
@@ -688,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -688,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
...
@@ -1069,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1069,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
774c9209
...
@@ -477,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -477,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
@@ -509,8 +509,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -509,8 +509,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
static
auto
...
@@ -518,8 +518,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -518,8 +518,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
@@ -532,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -532,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
...
@@ -878,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -878,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
774c9209
...
@@ -539,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -539,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
@@ -571,8 +571,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -571,8 +571,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
static
auto
...
@@ -580,8 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -580,8 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
@@ -594,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -594,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
...
@@ -948,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -948,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
View file @
774c9209
...
@@ -119,6 +119,15 @@ struct GemmGemmPadder
...
@@ -119,6 +119,15 @@ struct GemmGemmPadder
c_desc_mraw_nraw
,
make_tuple
(
MPerTile_
,
OPerTile_
),
Sequence
<
PadM
,
PadO
>
{});
c_desc_mraw_nraw
,
make_tuple
(
MPerTile_
,
OPerTile_
),
Sequence
<
PadM
,
PadO
>
{});
}
}
// C[M, Gemm1N] = C[M, N]
template
<
typename
C0Desc_MRaw_NRaw
>
__host__
__device__
constexpr
auto
PadC0Descriptor_M_N
(
const
C0Desc_MRaw_NRaw
&
c_desc_mraw_nraw
)
const
{
return
PadTensorDescriptor
(
c_desc_mraw_nraw
,
make_tuple
(
MPerTile_
,
NPerTile_
),
Sequence
<
PadM
,
PadN
>
{});
}
MPerTileType
MPerTile_
;
MPerTileType
MPerTile_
;
NPerTileType
NPerTile_
;
NPerTileType
NPerTile_
;
KPerTileType
KPerTile_
;
KPerTileType
KPerTile_
;
...
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
774c9209
...
@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
return
matrix_padder
.
PadCDescriptor_M_N
(
return
matrix_padder
.
PadCDescriptor_M_N
(
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
second
);
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
second
);
}
}
//
// C0
//
static
auto
MakeC0GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimN
,
CSpec
>
(
c_gs_ms_ns_lengths_vec
,
c_gs_ms_ns_strides_vec
);
}
// TODO: rename to G_MRaw_NRaw
static
auto
MakeC0GridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
return
MakeC0GridDescriptorPair
(
c_gs_ms_ns_lengths_vec
,
c_gs_ms_ns_strides_vec
).
first
;
}
static
auto
MakeC0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
return
matrix_padder
.
PadC0Descriptor_M_N
(
MakeC0GridDescriptorPair
(
c_gs_ms_ns_lengths_vec
,
c_gs_ms_ns_strides_vec
).
second
);
}
};
};
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment