Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
eff268e6
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "781cacd2e60ffbf358aaeeeee315a9b9d69c43a6"
Commit
eff268e6
authored
Aug 29, 2023
by
letaoqin
Browse files
remove _vec for bwd parameters
parent
2464edd0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
130 additions
and
130 deletions
+130
-130
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+39
-39
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+39
-39
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+26
-26
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+26
-26
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
eff268e6
...
...
@@ -352,31 +352,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
b1_gs_gemm1ns_gemm1ks_strides
_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
),
Number
<
B1K1
>
{});
}
...
...
@@ -385,8 +385,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -412,17 +412,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -451,17 +451,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
_vec
)
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
_vec
,
y_gs_ms_os_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -487,17 +487,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
...
...
@@ -509,10 +509,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
...
@@ -523,10 +523,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
_vec
)
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
_vec
,
q_gs_ms_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
...
...
@@ -534,10 +534,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
_vec
)
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
_vec
,
k_gs_ns_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
@@ -565,10 +565,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
_vec
)
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
_vec
,
d_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
eff268e6
...
...
@@ -360,31 +360,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
b1_gs_gemm1ns_gemm1ks_strides
_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
),
Number
<
B1K1
>
{});
}
...
...
@@ -393,8 +393,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -420,17 +420,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -459,17 +459,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
_vec
)
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
_vec
,
y_gs_ms_os_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -495,17 +495,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
...
...
@@ -517,17 +517,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
_vec
)
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
_vec
,
d_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
...
@@ -538,10 +538,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
_vec
)
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
_vec
,
q_gs_ms_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
...
...
@@ -549,10 +549,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
_vec
)
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
_vec
,
k_gs_ns_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
eff268e6
...
...
@@ -340,20 +340,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
//
...
...
@@ -361,8 +361,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -388,17 +388,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -409,17 +409,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// dQ = alpha * dS * K
//
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
_vec
)
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
_vec
,
y_gs_ms_os_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -445,17 +445,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
...
...
@@ -466,10 +466,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
v_grid_desc_n_o
,
Number
<
V_O1
>
{});
}
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
eff268e6
...
...
@@ -347,31 +347,31 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
b1_gs_gemm1ns_gemm1ks_strides
_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
),
Number
<
B1K1
>
{});
}
...
...
@@ -380,8 +380,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -407,17 +407,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -449,10 +449,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
_vec
)
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
_vec
,
q_gs_ms_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
...
...
@@ -460,16 +460,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
_vec
)
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
_vec
,
k_gs_ns_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment