"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "3eecbfb6ec231cd8012faceb8b6fbc87199db60d"
Commit 21cec2bb authored by letaoqin's avatar letaoqin
Browse files

Rename "biases" to "bias" in the batched MHA backward device op argument names

parent ff6d9e1f
...@@ -750,8 +750,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -750,8 +750,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -763,12 +763,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -763,12 +763,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -778,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -778,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases}, p_d0_grid_{p_acc0_bias},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -836,12 +836,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -836,12 +836,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_bias;
ignore = p_acc1_biases; ignore = p_acc1_bias;
ignore = acc0_biases_gs_ms_ns_lengths; ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides; ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -854,13 +854,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -854,13 +854,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const auto d0_grid_desc_m_n = MakeDGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths, const auto d0_grid_desc_m_n =
acc0_biases_gs_ms_ns_strides); MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides); acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
} }
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch( compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
...@@ -1176,8 +1176,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1176,8 +1176,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1189,12 +1189,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1189,12 +1189,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1213,8 +1213,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1213,8 +1213,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_vgrad_grid, p_vgrad_grid,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1226,10 +1226,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1226,10 +1226,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -1254,8 +1254,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1254,8 +1254,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const D0DataType* p_acc0_biases, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1267,12 +1267,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1267,12 +1267,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1292,8 +1292,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1292,8 +1292,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1305,10 +1305,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1305,10 +1305,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides, acc1_bias_gs_ms_gemm1ns_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
......
...@@ -766,8 +766,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -766,8 +766,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -779,12 +779,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -779,12 +779,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -794,7 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -794,7 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases}, p_d0_grid_{p_acc0_bias},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -851,9 +851,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -851,9 +851,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc1_biases; ignore = p_acc1_bias;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -867,13 +867,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -867,13 +867,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const auto d0_grid_desc_m_n = MakeDGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths, const auto d0_grid_desc_m_n =
acc0_biases_gs_ms_ns_strides); MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides); acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
} }
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch( compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
...@@ -1209,8 +1209,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1209,8 +1209,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1222,12 +1222,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1222,12 +1222,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1246,8 +1246,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1246,8 +1246,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_vgrad_grid, p_vgrad_grid,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1259,10 +1259,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1259,10 +1259,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -1287,8 +1287,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1287,8 +1287,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const void* p_acc0_biases, const void* p_acc0_bias,
const void* p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1300,12 +1300,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1300,12 +1300,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1325,8 +1325,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1325,8 +1325,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1338,10 +1338,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1338,10 +1338,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides, acc1_bias_gs_ms_gemm1ns_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment