Commit eff268e6 authored by letaoqin's avatar letaoqin
Browse files

remove _vec for bwd parameters

parent 2464edd0
...@@ -352,31 +352,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -352,31 +352,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// V in Gemm B1 position // V in Gemm B1 position
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_gemm1ns_gemm1ks_strides),
Number<B1K1>{}); Number<B1K1>{});
} }
...@@ -385,8 +385,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -385,8 +385,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -412,17 +412,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -412,17 +412,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -451,17 +451,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -451,17 +451,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// //
// YGrad in Gemm A position // YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides_vec) const std::vector<index_t>& y_gs_ms_os_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec), Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{}); Number<Y_O1>{});
} }
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -487,17 +487,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -487,17 +487,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto v_grid_desc_nraw_oraw = const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
...@@ -509,10 +509,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -509,10 +509,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
// //
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...@@ -523,10 +523,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -523,10 +523,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// //
// QGrad in Gemm C position // QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) const std::vector<index_t>& q_gs_ms_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
} }
// //
...@@ -534,10 +534,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -534,10 +534,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// //
// KGrad in Gemm C position // KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) const std::vector<index_t>& k_gs_ns_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -565,10 +565,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -565,10 +565,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
} }
// D in Gemm0 C position // D in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths_vec, static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides_vec) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths_vec, d_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
......
...@@ -360,31 +360,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -360,31 +360,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// V in Gemm B1 position // V in Gemm B1 position
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_gemm1ns_gemm1ks_strides),
Number<B1K1>{}); Number<B1K1>{});
} }
...@@ -393,8 +393,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -393,8 +393,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -420,17 +420,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -420,17 +420,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -459,17 +459,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -459,17 +459,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// YGrad in Gemm A position // YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides_vec) const std::vector<index_t>& y_gs_ms_os_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec), Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{}); Number<Y_O1>{});
} }
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -495,17 +495,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -495,17 +495,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto v_grid_desc_nraw_oraw = const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
...@@ -517,17 +517,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -517,17 +517,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
// D in Gemm0 C position // D in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths_vec, static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides_vec) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths_vec, d_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
// //
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...@@ -538,10 +538,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -538,10 +538,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// QGrad in Gemm C position // QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) const std::vector<index_t>& q_gs_ms_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
} }
// //
...@@ -549,10 +549,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -549,10 +549,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// KGrad in Gemm C position // KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) const std::vector<index_t>& k_gs_ns_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
......
...@@ -340,20 +340,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -340,20 +340,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// //
...@@ -361,8 +361,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -361,8 +361,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -388,17 +388,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -388,17 +388,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -409,17 +409,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -409,17 +409,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// //
// dQ = alpha * dS * K // dQ = alpha * dS * K
// //
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides_vec) const std::vector<index_t>& y_gs_ms_os_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec), Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{}); Number<Y_O1>{});
} }
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -445,17 +445,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -445,17 +445,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto v_grid_desc_nraw_oraw = const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
...@@ -466,10 +466,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -466,10 +466,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{}); return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
} }
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
......
...@@ -347,31 +347,31 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -347,31 +347,31 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// V in Gemm B1 position // V in Gemm B1 position
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_gemm1ns_gemm1ks_strides),
Number<B1K1>{}); Number<B1K1>{});
} }
...@@ -380,8 +380,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -380,8 +380,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -407,17 +407,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -407,17 +407,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -449,10 +449,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -449,10 +449,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// QGrad in Gemm C position // QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) const std::vector<index_t>& q_gs_ms_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
} }
// //
...@@ -460,16 +460,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -460,16 +460,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// KGrad in Gemm C position // KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) const std::vector<index_t>& k_gs_ns_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment