Commit 21cec2bb authored by letaoqin's avatar letaoqin
Browse files

change biases to bias in batched mha

parent ff6d9e1f
......@@ -750,8 +750,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -763,12 +763,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -778,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases},
p_d0_grid_{p_acc0_bias},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
......@@ -836,12 +836,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
ignore = p_acc0_bias;
ignore = p_acc1_bias;
ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -854,13 +854,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_desc_m_n = MakeDGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
const auto d0_grid_desc_m_n =
MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
......@@ -1176,8 +1176,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1189,12 +1189,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1213,8 +1213,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_qgrad_grid,
p_kgrad_grid,
p_vgrad_grid,
p_acc0_biases,
p_acc1_biases,
p_acc0_bias,
p_acc1_bias,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1226,10 +1226,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
......@@ -1254,8 +1254,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1267,12 +1267,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1292,8 +1292,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1305,10 +1305,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths,
acc1_bias_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
......
......@@ -766,8 +766,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -779,12 +779,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -794,7 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases},
p_d0_grid_{p_acc0_bias},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
......@@ -851,9 +851,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc1_biases;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
ignore = p_acc1_bias;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -867,13 +867,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_desc_m_n = MakeDGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
const auto d0_grid_desc_m_n =
MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
......@@ -1209,8 +1209,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1222,12 +1222,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1246,8 +1246,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_qgrad_grid,
p_kgrad_grid,
p_vgrad_grid,
p_acc0_biases,
p_acc1_biases,
p_acc0_bias,
p_acc1_bias,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1259,10 +1259,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
......@@ -1287,8 +1287,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const void* p_acc0_biases,
const void* p_acc1_biases,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1300,12 +1300,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1325,8 +1325,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1338,10 +1338,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths,
acc1_bias_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment