Commit ff6d9e1f authored by letaoqin's avatar letaoqin
Browse files

grouped bwd change p_accx_bias to p_accx_bias_vec

parent eff268e6
......@@ -755,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -787,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) &&
0 == p_acc1_bias.size()))
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -803,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias[i])
(ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias_vec[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
......@@ -970,8 +970,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_n_length_stride});
}
// TODO: implement bias addition
// ignore = p_acc0_bias;
// ignore = p_acc1_bias;
// ignore = p_acc0_bias_vec;
// ignore = p_acc1_bias_vec;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
......@@ -1209,8 +1209,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1230,8 +1230,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_bias,
p_acc1_bias,
p_acc0_bias_vec,
p_acc1_bias_vec,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1257,8 +1257,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1278,8 +1278,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_bias, // cast in struct Argument
p_acc1_bias, // cast in struct Argument
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -763,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -795,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) &&
0 == p_acc1_bias.size()))
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -811,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias[i])
(ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias_vec[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
......@@ -978,8 +978,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_n_length_stride});
}
// TODO: implement bias addition
// ignore = p_acc0_bias;
// ignore = p_acc1_bias;
// ignore = p_acc0_bias_vec;
// ignore = p_acc1_bias_vec;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
......@@ -1221,8 +1221,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1242,8 +1242,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_bias,
p_acc1_bias,
p_acc0_bias_vec,
p_acc1_bias_vec,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1269,8 +1269,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1290,8 +1290,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_bias, // cast in struct Argument
p_acc1_bias, // cast in struct Argument
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
problem_desc_vec,
a_element_op,
b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment