"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "ddfb38efa733f52ced8d02b03c9fd913e5d7e044"
Commit ff6d9e1f authored by letaoqin's avatar letaoqin
Browse files

grouped bwd change p_accx_bias to p_accx_bias_vec

parent eff268e6
...@@ -755,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -755,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -787,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -787,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias.size())) 0 == p_acc1_bias_vec.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -803,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -803,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid = const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_) (ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias[i]) ? static_cast<const D0DataType*>(p_acc0_bias_vec[i])
: nullptr; : nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
...@@ -970,8 +970,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -970,8 +970,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_n_length_stride}); d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_bias; // ignore = p_acc0_bias_vec;
// ignore = p_acc1_bias; // ignore = p_acc1_bias_vec;
// ignore = acc0_bias_gs_ms_ns_lengths; // ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides; // ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths; // ignore = acc1_bias_gs_ms_gemm1ns_lengths;
...@@ -1209,8 +1209,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1209,8 +1209,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1230,8 +1230,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1230,8 +1230,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_bias, p_acc0_bias_vec,
p_acc1_bias, p_acc1_bias_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1257,8 +1257,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1257,8 +1257,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1278,8 +1278,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1278,8 +1278,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_bias, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -763,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -763,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -795,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -795,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias.size())) 0 == p_acc1_bias_vec.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -811,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -811,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid = const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_) (ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias[i]) ? static_cast<const D0DataType*>(p_acc0_bias_vec[i])
: nullptr; : nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
...@@ -978,8 +978,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -978,8 +978,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_n_length_stride}); d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_bias; // ignore = p_acc0_bias_vec;
// ignore = p_acc1_bias; // ignore = p_acc1_bias_vec;
// ignore = acc0_bias_gs_ms_ns_lengths; // ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides; // ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths; // ignore = acc1_bias_gs_ms_gemm1ns_lengths;
...@@ -1221,8 +1221,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1221,8 +1221,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1242,8 +1242,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1242,8 +1242,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_bias, p_acc0_bias_vec,
p_acc1_bias, p_acc1_bias_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1269,8 +1269,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1269,8 +1269,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1290,8 +1290,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1290,8 +1290,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_bias, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment